shape.cpp 16.2 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
4
#include <migraphx/serialize.hpp>
5
#include <migraphx/permutation.hpp>
6
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
7
8
9
#include <numeric>
#include <algorithm>
#include <functional>
10
#include <unordered_map>
Paul's avatar
Paul committed
11
#include <iostream>
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
namespace migraphx {
Paul's avatar
Paul committed
14
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
17
struct shape_impl
{
Paul's avatar
Paul committed
18
19
    static std::shared_ptr<shape_impl> default_shape()
    {
20
        static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
Paul's avatar
Paul committed
21
22
23
        return result;
    }

Paul Fultz II's avatar
Paul Fultz II committed
24
    shape_impl() : m_type(shape::float_type) {}
Paul's avatar
Paul committed
25

Paul Fultz II's avatar
Paul Fultz II committed
26
27
28
29
    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
    {
        assert(t != shape::tuple_type);
    }
Paul's avatar
Paul committed
30
31
32
    shape_impl(shape::type_t t, std::vector<std::size_t> l)
        : m_type(t), m_lens(std::move(l)), m_standard(true)
    {
Paul Fultz II's avatar
Paul Fultz II committed
33
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
34
35
36
37
38
39
        this->calculate_strides();
        assert(m_lens.size() == m_strides.size());
    }
    shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
        : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
    {
Paul Fultz II's avatar
Paul Fultz II committed
40
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
41
        assert(m_lens.size() == m_strides.size());
Khalique's avatar
Khalique committed
42
43
        // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
        //        "At least one stride must be non-zero");
44
45
        m_standard = this->elements() == this->element_space() and
                     std::is_sorted(m_strides.rbegin(), m_strides.rend());
Paul's avatar
Paul committed
46
    }
Paul Fultz II's avatar
Paul Fultz II committed
47

charlie's avatar
charlie committed
48
    shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
49
        : m_type(t), m_dyn_dims(std::move(dims))
50
51
52
    {
    }

53
54
    shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}

Paul's avatar
Paul committed
55
    shape::type_t m_type;
Paul Fultz II's avatar
Paul Fultz II committed
56
57
58
59
    std::vector<std::size_t> m_lens    = {};
    std::vector<std::size_t> m_strides = {};
    std::vector<shape> m_shapes        = {};
    bool m_standard                    = false;
60

61
    std::vector<shape::dynamic_dimension> m_dyn_dims = {};
Paul's avatar
Paul committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

    void calculate_strides()
    {
        m_strides.clear();
        m_strides.resize(m_lens.size(), 0);
        if(m_strides.empty())
            return;
        m_strides.back() = 1;
        std::partial_sum(m_lens.rbegin(),
                         m_lens.rend() - 1,
                         m_strides.rbegin() + 1,
                         std::multiplies<std::size_t>());
    }

    std::size_t element_space() const
    {
78
        if(not m_dyn_dims.empty())
79
        {
80
            auto maxes = max_lens();
charlie's avatar
charlie committed
81
            return std::accumulate(maxes.begin(), maxes.end(), std::size_t{1}, std::multiplies<>());
82
83
        }

Paul's avatar
Paul committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::inner_product(m_lens.begin(),
                                  m_lens.end(),
                                  m_strides.begin(),
                                  std::size_t{0},
                                  std::plus<std::size_t>{},
                                  [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
               1;
    }

    std::size_t elements() const
    {
98
        if(not m_dyn_dims.empty())
99
100
101
102
        {
            MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
        }

Paul's avatar
Paul committed
103
104
105
106
107
108
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::accumulate(
            m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }
109

110
111
112
    std::vector<std::size_t> min_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
charlie's avatar
charlie committed
113
114
115
116
        std::transform(m_dyn_dims.cbegin(),
                       m_dyn_dims.cend(),
                       ret.begin(),
                       [](shape::dynamic_dimension x) { return x.min; });
117
118
119
120
121
122
        return ret;
    }

    std::vector<std::size_t> max_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
charlie's avatar
charlie committed
123
124
125
126
        std::transform(m_dyn_dims.cbegin(),
                       m_dyn_dims.cend(),
                       ret.begin(),
                       [](shape::dynamic_dimension x) { return x.max; });
127
128
129
130
131
132
        return ret;
    }

    std::vector<std::size_t> opt_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
charlie's avatar
charlie committed
133
134
135
136
        std::transform(m_dyn_dims.cbegin(),
                       m_dyn_dims.cend(),
                       ret.begin(),
                       [](shape::dynamic_dimension x) { return x.opt; });
137
138
139
        return ret;
    }

140
    std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
Paul's avatar
Paul committed
141
};
Paul's avatar
Paul committed
142

143
144
145
146
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
Paul Fultz II's avatar
Paul Fultz II committed
147
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
148
149
150
    return result;
}

151
152
153
154
// Enumerator spelling of `t` (e.g. "float_type"); throws for out-of-range values.
std::string shape::name(shape::type_t t)
{
    switch(t)
    {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
    case x: return #x;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    case tuple_type: return "tuple_type";
    }
    // Only reachable for a value outside the enumeration.
    MIGRAPHX_THROW("Invalid type");
}

// C++ element type spelling for `t`; tuples have none, so they throw.
std::string shape::cpp_type(shape::type_t t)
{
    switch(t)
    {
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
    case x: return #t;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
    }
    // Only reachable for a value outside the enumeration.
    MIGRAPHX_THROW("Invalid type");
}

Paul's avatar
Paul committed
176
177
178
shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Paul's avatar
Paul committed
179
shape::shape(type_t t, std::vector<std::size_t> l)
Paul's avatar
Paul committed
180
    : impl(std::make_shared<shape_impl>(t, std::move(l)))
Paul's avatar
Paul committed
181
182
183
{
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
Paul's avatar
Paul committed
184
    : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
Paul's avatar
Paul committed
185
186
187
{
}

188
189
190
191
192
shape::shape(type_t t, std::initializer_list<std::size_t> d)
    : shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}

charlie's avatar
charlie committed
193
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
194
195
196
    : impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
Paul Fultz II's avatar
Paul Fultz II committed
197

198
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
199

200
201
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}

202
203
204
205
206
207
208
209
210
211
// Builds a shape whose lens equal `l` and whose strides follow the memory
// order given by `perm`.
shape shape::from_permutation(type_t t,
                              const std::vector<std::size_t>& l,
                              const std::vector<int64_t>& perm)
{
    // Standard shape over the permuted lens, then reorder it back so the
    // lens line up with `l` while keeping the permuted strides.
    const std::vector<std::size_t> permuted_lens = reorder_dims(l, perm);
    const shape permuted{t, permuted_lens};
    shape result = reorder_shape(permuted, invert_permutation(perm));
    assert(result.lens() == l);
    return result;
}

Paul's avatar
Paul committed
212
shape::type_t shape::type() const { return impl->m_type; }
213

Paul's avatar
Paul committed
214
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
215

Paul's avatar
Paul committed
216
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
217

Paul's avatar
Paul committed
218
std::size_t shape::elements() const { return impl->elements(); }
219

Paul's avatar
Paul committed
220
221
// Storage size in bytes: for a tuple, the sum over its sub-shapes; otherwise
// the element size times element_space().
std::size_t shape::bytes() const
{
    const auto& subs = this->sub_shapes();
    if(not subs.empty())
    {
        std::size_t total = 0;
        for(const auto& sub : subs)
            total += sub.bytes();
        return total;
    }
    std::size_t elem_size = 0;
    this->visit_type([&](auto as) { elem_size = as.size(); });
    return elem_size * this->element_space();
}
236

Scott Thornton's avatar
Scott Thornton committed
237
238
239
// Size in bytes of one element; tuple shapes report 0.
std::size_t shape::type_size() const
{
    if(not this->sub_shapes().empty())
        return 0;
    std::size_t size = 0;
    this->visit_type([&](auto as) { size = as.size(); });
    return size;
}
244

Paul's avatar
Paul committed
245
246
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
247
248
249
250
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
251
252
253
254
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
charlie's avatar
charlie committed
255

Paul's avatar
Paul committed
256
257
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
258
259
260
261
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
262
263
264
265
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
charlie's avatar
charlie committed
266

Paul's avatar
Paul committed
267
268
std::size_t shape::index(std::size_t i) const
{
269
270
271
272
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
273
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
274
    if(this->standard())
Paul's avatar
Paul committed
275
276
        return i;
    else
Paul's avatar
Paul committed
277
    {
Paul's avatar
Paul committed
278
        std::size_t s      = 1;
Paul's avatar
Paul committed
279
        std::size_t result = 0;
Paul's avatar
Paul committed
280
        for(std::size_t j = 0; j < this->lens().size(); j++)
Paul's avatar
Paul committed
281
        {
Paul's avatar
Paul committed
282
            const std::size_t k      = this->lens().size() - j - 1;
Paul's avatar
Paul committed
283
            const std::size_t stride = this->strides()[k];
Paul's avatar
Paul committed
284
285
            const std::size_t len    = this->lens()[k];
            const std::size_t idx    = (i % (s * len)) / s;
Paul's avatar
Paul committed
286
287
288
289
290
            result += stride * idx;
            s *= len;
        }
        return result;
    }
Paul's avatar
Paul committed
291
}
292
293
294
295
296
297

// Multi-dimensional coordinates of linear index `i` (standard shapes only).
std::vector<std::size_t> shape::multi(std::size_t i) const
{
    assert(this->standard());
    std::vector<std::size_t> coords(this->lens().size());
    this->multi_copy(i, coords.data(), coords.data() + coords.size());
    return coords;
}

void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
{
    assert(this->standard());
    (void)end;
    assert(lens().size() <= (end - start));
308
    std::transform(strides().begin(),
Shucai Xiao's avatar
Shucai Xiao committed
309
310
                   strides().end(),
                   lens().begin(),
311
                   start,
Shucai Xiao's avatar
Shucai Xiao committed
312
313
314
315
                   [&](std::size_t stride, std::size_t len) {
                       assert(len > 0 and stride > 0);
                       return (i / stride) % len;
                   });
316
317
}

Paul Fultz II's avatar
Paul Fultz II committed
318
319
bool shape::packed() const
{
320
321
322
323
    if(this->dynamic())
    {
        return false;
    }
Paul Fultz II's avatar
Paul Fultz II committed
324
325
    return this->sub_shapes().empty() and this->elements() == this->element_space();
}
Paul's avatar
Paul committed
326

Paul's avatar
Paul committed
327
328
// True when the strides are not non-increasing from first to last dimension.
// Zero (broadcast) strides are ignored; dynamic shapes report false.
bool shape::transposed() const
{
    if(this->dynamic())
        return false;
    const auto& steps = this->strides();
    if(not this->broadcasted())
        return not std::is_sorted(steps.rbegin(), steps.rend());

    // TODO: Use a filter_iterator instead
    std::vector<std::size_t> nonzero;
    nonzero.reserve(steps.size());
    std::remove_copy(steps.begin(), steps.end(), std::back_inserter(nonzero), std::size_t{0});
    return not std::is_sorted(nonzero.rbegin(), nonzero.rend());
}
Paul's avatar
Paul committed
349
350
351

bool shape::broadcasted() const
{
352
353
354
355
    if(this->dynamic())
    {
        return false;
    }
Paul's avatar
Paul committed
356
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
357
358
359
360
    return std::accumulate(this->strides().begin(),
                           this->strides().end(),
                           std::size_t{1},
                           std::multiplies<std::size_t>()) == 0;
Paul's avatar
Paul committed
361
362
}

Khalique's avatar
Khalique committed
363
364
bool shape::scalar() const
{
365
366
367
368
    if(this->dynamic())
    {
        return false;
    }
Khalique's avatar
Khalique committed
369
370
    assert(this->lens().size() == this->strides().size());
    // if any stride > 0, then accumulate will return false
Paul Fultz II's avatar
Paul Fultz II committed
371
372
    return this->sub_shapes().empty() and
           std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
Khalique's avatar
Khalique committed
373
374
}

Paul's avatar
Paul committed
375
bool shape::standard() const { return impl->m_standard; }
Paul's avatar
Paul committed
376

377
378
379
380
381
382
383
384
shape shape::normalize_standard() const
{
    if(this->standard())
        return {this->type(), this->lens()};
    else
        return *this;
}

385
386
// New shape of type `t` with lens `l`, keeping this shape's stride
// permutation. Throws for dynamic shapes.
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    assert(l.size() == this->lens().size());
    return shape::from_permutation(t, l, find_permutation(*this));
}

// Same as above, keeping this shape's type.
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    return this->with_lens(this->type(), l);
}

405
406
407
408
409
410
411
// Same layout with element type `t`. The impl is copied first because impls
// can be shared between shapes (see shape_impl::default_shape).
shape shape::with_type(type_t t) const
{
    auto updated    = this->impl->copy();
    updated->m_type = t;
    return shape{updated};
}

Paul's avatar
Paul committed
412
std::size_t shape::element_space() const { return impl->element_space(); }
Paul's avatar
Paul committed
413

414
std::string shape::type_string() const { return name(this->type()); }
Paul's avatar
Paul committed
415

416
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
charlie's avatar
charlie committed
417
418
419

const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }

420
std::vector<std::size_t> shape::min_lens() const
charlie's avatar
charlie committed
421
{
charlie's avatar
charlie committed
422
    if(not this->dynamic())
charlie's avatar
charlie committed
423
    {
424
        return this->lens();
charlie's avatar
charlie committed
425
    }
charlie's avatar
charlie committed
426
427
    return impl->min_lens();
    ;
charlie's avatar
charlie committed
428
429
}

430
std::vector<std::size_t> shape::max_lens() const
charlie's avatar
charlie committed
431
{
charlie's avatar
charlie committed
432
    if(not this->dynamic())
charlie's avatar
charlie committed
433
    {
434
        return this->lens();
charlie's avatar
charlie committed
435
    }
charlie's avatar
charlie committed
436
437
    return impl->max_lens();
    ;
charlie's avatar
charlie committed
438
439
}

440
std::vector<std::size_t> shape::opt_lens() const
charlie's avatar
charlie committed
441
{
charlie's avatar
charlie committed
442
    if(not this->dynamic())
charlie's avatar
charlie committed
443
    {
444
        return this->lens();
charlie's avatar
charlie committed
445
    }
charlie's avatar
charlie committed
446
447
    return impl->opt_lens();
    ;
charlie's avatar
charlie committed
448
449
}

Paul's avatar
Paul committed
450
451
bool operator==(const shape& x, const shape& y)
{
452
453
454
455
456
457
458
459
    if(x.dynamic() and y.dynamic())
    {
        return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
                                    x.sub_shapes() == y.sub_shapes());
    }
    return x.impl == y.impl or
           (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
            x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
Paul's avatar
Paul committed
460
}
461

Paul's avatar
Paul committed
462
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
Paul's avatar
Paul committed
463

Paul's avatar
Paul committed
464
465
std::ostream& operator<<(std::ostream& os, const shape& x)
{
Paul Fultz II's avatar
Paul Fultz II committed
466
467
    if(x.sub_shapes().empty())
    {
468
469
470
471
472
473
474
475
476
477
478
479
        if(x.dynamic())
        {
            os << "dynamic, ";
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.dyn_dims()) << "}";
        }
        else
        {
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.lens()) << "}, ";
            os << "{" << to_string_range(x.strides()) << "}";
        }
Paul Fultz II's avatar
Paul Fultz II committed
480
481
482
483
484
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
Paul's avatar
Paul committed
485
486
487
    return os;
}

488
489
// Maps a type name to its enumerator. Accepts the enumerator spelling
// (as produced by name()), the C++ type spelling (as produced by cpp_type()),
// and "tuple_type"/"tuple". Unknown names throw std::out_of_range from at().
shape::type_t shape::parse_type(const std::string& s)
{
    static const std::unordered_map<std::string, shape::type_t> m = {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
                                                                            tuple_type},
        {"tuple", tuple_type}};
// Undefine the generator so it does not leak into the rest of the TU,
// matching the other generator macros in this file.
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    return m.at(s);
}

Paul Fultz II's avatar
Paul Fultz II committed
498
499
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }

500
501
502
// Serializes a shape. Keys are inserted in the order: type, then either
// (lens, strides) or (dynamic, min_lens, max_lens, opt_lens), then sub_shapes.
void migraphx_to_value(value& v, const shape& s)
{
    value result;
    result["type"] = migraphx::to_value(s.type_string());
    if(not s.dynamic())
    {
        result["lens"]    = migraphx::to_value(s.lens());
        result["strides"] = migraphx::to_value(s.strides());
    }
    else
    {
        result["dynamic"]  = migraphx::to_value(s.dynamic());
        result["min_lens"] = migraphx::to_value(s.min_lens());
        result["max_lens"] = migraphx::to_value(s.max_lens());
        result["opt_lens"] = migraphx::to_value(s.opt_lens());
    }
    result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    v = result;
}
520

521
522
void migraphx_from_value(const value& v, shape& s)
{
Paul Fultz II's avatar
Paul Fultz II committed
523
524
525
526
527
528
529
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
charlie's avatar
charlie committed
530
        if(v.contains("dynamic"))
531
        {
532
533
534
            auto mins  = v.at("min_lens").to_vector<std::size_t>();
            auto maxes = v.at("max_lens").to_vector<std::size_t>();
            auto opts  = v.at("opt_lens").to_vector<std::size_t>();
charlie's avatar
charlie committed
535
            assert(mins.size() == maxes.size() and maxes.size() == opts.size());
536
            auto num_dims = mins.size();
charlie's avatar
charlie committed
537
            std::vector<shape::dynamic_dimension> dyn_dims(num_dims);
538
            for(int i = 0; i < num_dims; ++i)
539
            {
540
                dyn_dims.at(i) = shape::dynamic_dimension{mins[i], maxes[i], opts[i]};
541
            }
542
            s = shape{shape::parse_type(t), dyn_dims};
543
544
545
546
547
548
549
        }
        else
        {
            s = shape{shape::parse_type(t),
                      v.at("lens").to_vector<std::size_t>(),
                      v.at("strides").to_vector<std::size_t>()};
        }
Paul Fultz II's avatar
Paul Fultz II committed
550
    }
551
552
}

Paul's avatar
Paul committed
553
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
554
} // namespace migraphx