"git@developer.sourcefind.cn:Fzc7075/nunchaku.git" did not exist on "67723598f522456820f8837d95626ddeca74b33b"
shape.cpp 22.4 KB
Newer Older
1
2
3
/*
 * The MIT License (MIT)
 *
4
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24

Paul's avatar
Paul committed
25
26
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
27
#include <migraphx/serialize.hpp>
28
#include <migraphx/permutation.hpp>
Charlie Lin's avatar
Charlie Lin committed
29
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
30
31
32
#include <numeric>
#include <algorithm>
#include <functional>
33
#include <unordered_map>
Paul's avatar
Paul committed
34
#include <iostream>
Paul's avatar
Paul committed
35

Paul's avatar
Paul committed
36
namespace migraphx {
Paul's avatar
Paul committed
37
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
38

Paul's avatar
Paul committed
39
40
struct shape_impl
{
Paul's avatar
Paul committed
41
42
    static std::shared_ptr<shape_impl> default_shape()
    {
43
        static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
Paul's avatar
Paul committed
44
45
46
        return result;
    }

Paul Fultz II's avatar
Paul Fultz II committed
47
    shape_impl() : m_type(shape::float_type) {}
Paul's avatar
Paul committed
48

Paul Fultz II's avatar
Paul Fultz II committed
49
50
51
52
    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
    {
        assert(t != shape::tuple_type);
    }
53

Paul's avatar
Paul committed
54
55
56
    shape_impl(shape::type_t t, std::vector<std::size_t> l)
        : m_type(t), m_lens(std::move(l)), m_standard(true)
    {
Paul Fultz II's avatar
Paul Fultz II committed
57
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
58
59
        this->calculate_strides();
    }
60

Paul's avatar
Paul committed
61
62
63
    shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
        : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
    {
Paul Fultz II's avatar
Paul Fultz II committed
64
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
65
        assert(m_lens.size() == m_strides.size());
66
        m_standard = this->elements() == this->element_space() and not skips() and
67
                     std::is_sorted(m_strides.rbegin(), m_strides.rend());
Paul's avatar
Paul committed
68
    }
Paul Fultz II's avatar
Paul Fultz II committed
69

Charlie Lin's avatar
Charlie Lin committed
70
71
72
73
74
    shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
        : m_type(t), m_dyn_dims(std::move(dims))
    {
    }

75
76
77
    shape_impl(shape::type_t t,
               std::vector<std::size_t> mins,
               std::vector<std::size_t> maxes,
78
               std::vector<std::set<std::size_t>> optimals_list)
79
80
        : m_type(t)
    {
81
        if(optimals_list.empty())
82
        {
83
84
85
86
87
88
89
90
91
92
93
94
            for(size_t i = 0; i < mins.size(); ++i)
            {
                m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
            }
        }
        else
        {
            assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
            for(size_t i = 0; i < mins.size(); ++i)
            {
                m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
            }
95
96
97
        }
    }

Paul Fultz II's avatar
Paul Fultz II committed
98
    shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
Charlie Lin's avatar
Charlie Lin committed
99

Paul's avatar
Paul committed
100
    shape::type_t m_type;
Paul Fultz II's avatar
Paul Fultz II committed
101
102
103
104
    std::vector<std::size_t> m_lens    = {};
    std::vector<std::size_t> m_strides = {};
    std::vector<shape> m_shapes        = {};
    bool m_standard                    = false;
Paul's avatar
Paul committed
105

Charlie Lin's avatar
Charlie Lin committed
106
107
    std::vector<shape::dynamic_dimension> m_dyn_dims = {};

Paul's avatar
Paul committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
    void calculate_strides()
    {
        m_strides.clear();
        m_strides.resize(m_lens.size(), 0);
        if(m_strides.empty())
            return;
        m_strides.back() = 1;
        std::partial_sum(m_lens.rbegin(),
                         m_lens.rend() - 1,
                         m_strides.rbegin() + 1,
                         std::multiplies<std::size_t>());
    }

    std::size_t element_space() const
    {
Charlie Lin's avatar
Charlie Lin committed
123
124
125
126
127
128
        if(not m_dyn_dims.empty())
        {
            auto maxes = max_lens();
            return std::accumulate(maxes.begin(), maxes.end(), std::size_t{1}, std::multiplies<>());
        }

Paul's avatar
Paul committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::inner_product(m_lens.begin(),
                                  m_lens.end(),
                                  m_strides.begin(),
                                  std::size_t{0},
                                  std::plus<std::size_t>{},
                                  [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
               1;
    }

    std::size_t elements() const
    {
Charlie Lin's avatar
Charlie Lin committed
143
144
145
146
147
        if(not m_dyn_dims.empty())
        {
            MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
        }

Paul's avatar
Paul committed
148
149
150
151
152
153
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::accumulate(
            m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }
154

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
    std::size_t get_index(size_t i) const
    {
        std::size_t result = 0;
        std::size_t s      = 1;

        for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
        {
            std::size_t stride = m_strides[k];
            std::size_t len    = m_lens[k];
            std::size_t idx    = (i % (s * len)) / s;
            result += stride * idx;
            s *= len;
        }
        return result;
    }

Charlie Lin's avatar
Charlie Lin committed
171
172
173
174
175
176
    std::vector<std::size_t> min_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
        std::transform(m_dyn_dims.cbegin(),
                       m_dyn_dims.cend(),
                       ret.begin(),
177
                       [](const shape::dynamic_dimension& x) { return x.min; });
Charlie Lin's avatar
Charlie Lin committed
178
179
180
181
182
183
184
185
186
        return ret;
    }

    std::vector<std::size_t> max_lens() const
    {
        std::vector<std::size_t> ret(m_dyn_dims.size());
        std::transform(m_dyn_dims.cbegin(),
                       m_dyn_dims.cend(),
                       ret.begin(),
187
                       [](const shape::dynamic_dimension& x) { return x.max; });
Charlie Lin's avatar
Charlie Lin committed
188
189
190
        return ret;
    }

191
    std::vector<std::set<std::size_t>> opt_lens() const
Charlie Lin's avatar
Charlie Lin committed
192
    {
193
        std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
Charlie Lin's avatar
Charlie Lin committed
194
195
196
        std::transform(m_dyn_dims.cbegin(),
                       m_dyn_dims.cend(),
                       ret.begin(),
197
                       [](const shape::dynamic_dimension& x) { return x.optimals; });
Charlie Lin's avatar
Charlie Lin committed
198
199
        return ret;
    }
200

201
202
203
204
205
206
207
208
209
    // Does the shape skip over elements?
    bool skips() const
    {
        assert(m_lens.size() == m_strides.size());
        if(elements() == 1)
            return false;
        return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
    }

210
    std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
Paul's avatar
Paul committed
211
};
Paul's avatar
Paul committed
212

213
214
215
216
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
Paul Fultz II's avatar
Paul Fultz II committed
217
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
218
219
220
    return result;
}

221
222
223
224
// Enumerator spelling of a type tag (e.g. "float_type"); throws for values
// that are not a known type.
std::string shape::name(shape::type_t t)
{
    switch(t)
    {
// One `case` per element type, returning the stringized enumerator.
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
    case x: return #x;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    // tuple_type is not part of the visited element types.
    case tuple_type: return "tuple_type";
    }
    MIGRAPHX_THROW("Invalid type");
}
233

234
235
236
237
// C++ spelling of the element type behind a type tag (e.g. "float").
// Tuples have no single C++ type, so they throw, as do unknown values.
std::string shape::cpp_type(shape::type_t t)
{
    switch(t)
    {
// One `case` per element type, returning the stringized C++ type.
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
    case x: return #t;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
    }
    MIGRAPHX_THROW("Invalid type");
}
Paul's avatar
Paul committed
246
247
248
// Whether the element type identified by `t` is an integral type, as
// reported by the type descriptor that `visit` dispatches to.
bool shape::is_integral(shape::type_t t)
{
    bool integral = false;
    visit(t, [&integral](auto as) { integral = as.is_integral(); });
    return integral;
}
252

Paul's avatar
Paul committed
253
254
255
// Default shape: shares the single cached default implementation.
shape::shape() : impl(shape_impl::default_shape()) {}

// Single-element shape of type t.
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}

// Standard shape; strides are derived from the lens.
shape::shape(type_t t, std::vector<std::size_t> l)
    : impl(std::make_shared<shape_impl>(t, std::move(l)))
{
}

// Shape with explicitly provided strides.
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
    : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{
}

// Braced-list convenience overload; forwards to the standard-shape constructor.
shape::shape(type_t t, std::initializer_list<std::size_t> d)
    : shape::shape(t, std::vector<std::size_t>(d))
{
}

// Dynamic shape from per-dimension ranges.
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
    : impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}

// Dynamic shape from parallel min/max lists and optional per-dimension optimals.
shape::shape(type_t t,
             std::vector<std::size_t> mins,
             std::vector<std::size_t> maxes,
             std::vector<std::set<std::size_t>> optimals_list)
    : impl(std::make_shared<shape_impl>(
          t, std::move(mins), std::move(maxes), std::move(optimals_list)))
{
}

// Tuple shape made of the given member shapes.
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}

// Adopt an existing implementation.
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}

290
291
292
293
294
295
296
297
298
299
// Build a shape whose logical lens are `l` but whose memory layout follows
// `perm`: a standard shape is laid out over the permuted lens, then the
// inverse permutation restores the logical dimension order.
shape shape::from_permutation(type_t t,
                              const std::vector<std::size_t>& l,
                              const std::vector<int64_t>& perm)
{
    auto permuted_lens = reorder_dims(l, perm);
    auto inverse       = invert_permutation(perm);
    shape permuted     = reorder_shape({t, permuted_lens}, inverse);
    assert(permuted.lens() == l);
    return permuted;
}

Paul's avatar
Paul committed
300
shape::type_t shape::type() const { return impl->m_type; }
Charlie Lin's avatar
Charlie Lin committed
301

302
303
304
305
306
307
308
309
const std::vector<std::size_t>& shape::lens() const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
    }
    return impl->m_lens;
}
Charlie Lin's avatar
Charlie Lin committed
310

311
312
313
314
315
316
317
318
const std::vector<std::size_t>& shape::strides() const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
    }
    return impl->m_strides;
}
Charlie Lin's avatar
Charlie Lin committed
319

320
321
322
323
324
325
326
327
328
std::size_t shape::ndim() const
{
    if(this->dynamic())
    {
        return dyn_dims().size();
    }
    return lens().size();
}

Paul's avatar
Paul committed
329
std::size_t shape::elements() const { return impl->elements(); }
Charlie Lin's avatar
Charlie Lin committed
330

Paul's avatar
Paul committed
331
332
std::size_t shape::bytes() const
{
Paul Fultz II's avatar
Paul Fultz II committed
333
334
335
336
337
338
339
340
341
342
343
344
345
    if(this->sub_shapes().empty())
    {
        std::size_t n = 0;
        this->visit_type([&](auto as) { n = as.size(); });
        return n * this->element_space();
    }
    else
    {
        return std::accumulate(this->sub_shapes().begin(),
                               this->sub_shapes().end(),
                               std::size_t{0},
                               [&](auto x, auto y) { return x + y.bytes(); });
    }
Paul's avatar
Paul committed
346
}
Charlie Lin's avatar
Charlie Lin committed
347

Scott Thornton's avatar
Scott Thornton committed
348
349
350
// Size in bytes of one element; 0 for tuple shapes, which have no single
// element type.
std::size_t shape::type_size() const
{
    if(not this->sub_shapes().empty())
        return 0;
    std::size_t element_size = 0;
    this->visit_type([&](auto as) { element_size = as.size(); });
    return element_size;
}
Charlie Lin's avatar
Charlie Lin committed
355

Paul's avatar
Paul committed
356
357
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
Charlie Lin's avatar
Charlie Lin committed
358
359
360
361
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
362
363
364
365
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
Charlie Lin's avatar
Charlie Lin committed
366

Paul's avatar
Paul committed
367
368
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
Charlie Lin's avatar
Charlie Lin committed
369
370
371
372
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
373
374
375
376
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
Charlie Lin's avatar
Charlie Lin committed
377

Paul's avatar
Paul committed
378
379
std::size_t shape::index(std::size_t i) const
{
Charlie Lin's avatar
Charlie Lin committed
380
381
382
383
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
    }
Paul's avatar
Paul committed
384
    assert(this->lens().size() == this->strides().size());
Paul's avatar
Paul committed
385
    if(this->standard())
Paul's avatar
Paul committed
386
        return i;
387
388

    return impl->get_index(i);
Paul's avatar
Paul committed
389
}
390

391
std::vector<std::size_t> shape::multi(std::size_t idx) const
392
{
393
    assert(idx < elements());
394
    std::vector<std::size_t> indices(lens().size());
395
    multi_copy(idx, indices.data(), indices.data() + lens().size());
396
397
398
    return indices;
}

399
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
400
{
401
    size_t tidx = idx;
402
    (void)end;
403
    assert(idx < elements());
404
    assert(lens().size() <= (end - start));
405
406
407
408
409
410
    for(size_t ii = lens().size() - 1; ii > 0; ii--)
    {
        *(start + ii) = tidx % lens()[ii];
        tidx          = tidx / lens()[ii];
    }
    *start = tidx;
411
412
}

Paul Fultz II's avatar
Paul Fultz II committed
413
414
bool shape::packed() const
{
Charlie Lin's avatar
Charlie Lin committed
415
416
417
418
    if(this->dynamic())
    {
        return false;
    }
419
420
    return this->sub_shapes().empty() and not impl->skips() and
           this->elements() == this->element_space();
Paul Fultz II's avatar
Paul Fultz II committed
421
}
Paul's avatar
Paul committed
422

Paul's avatar
Paul committed
423
424
// A static shape is transposed when its non-zero strides, read from the
// innermost to the outermost dimension, ever decrease. Zero (broadcast)
// strides are ignored. Dynamic shapes are never considered transposed.
bool shape::transposed() const
{
    if(this->dynamic())
    {
        return false;
    }
    const auto& strd = this->strides();
    // Non-zero strides are >= 1, so starting at 0 never triggers a false positive.
    std::size_t previous = 0;
    for(auto it = strd.rbegin(); it != strd.rend(); ++it)
    {
        if(*it == 0)
            continue;
        if(*it < previous)
            return true;
        previous = *it;
    }
    return false;
}
Paul's avatar
Paul committed
445
446
447

bool shape::broadcasted() const
{
Charlie Lin's avatar
Charlie Lin committed
448
449
450
451
    if(this->dynamic())
    {
        return false;
    }
Paul's avatar
Paul committed
452
    assert(this->lens().size() == this->strides().size());
453
454
    return std::any_of(
        this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
Paul's avatar
Paul committed
455
456
}

Khalique's avatar
Khalique committed
457
458
bool shape::scalar() const
{
Charlie Lin's avatar
Charlie Lin committed
459
460
461
462
    if(this->dynamic())
    {
        return false;
    }
Khalique's avatar
Khalique committed
463
464
    assert(this->lens().size() == this->strides().size());
    // if any stride > 0, then accumulate will return false
Paul Fultz II's avatar
Paul Fultz II committed
465
466
    return this->sub_shapes().empty() and
           std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
Khalique's avatar
Khalique committed
467
468
}

Paul's avatar
Paul committed
469
bool shape::standard() const { return impl->m_standard; }
Paul's avatar
Paul committed
470

471
472
473
474
475
476
477
478
shape shape::normalize_standard() const
{
    if(this->standard())
        return {this->type(), this->lens()};
    else
        return *this;
}

479
480
// New shape with type `t` and lens `l` whose memory order follows the same
// dimension permutation as this shape. Throws for dynamic shapes.
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
    }
    assert(l.size() == this->lens().size());
    return shape::from_permutation(t, l, find_permutation(*this));
}

// Same as the typed overload, keeping this shape's element type. The
// dynamic-shape check (and its error message) come from that overload.
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
    return this->with_lens(this->type(), l);
}

499
500
501
502
503
504
505
// Same layout with a different element type. The implementation is cloned
// first so the (possibly shared) original is not mutated.
shape shape::with_type(type_t t) const
{
    auto retyped    = impl->copy();
    retyped->m_type = t;
    return {retyped};
}

506
507
// Dynamic counterpart of this shape. Tuples convert member by member,
// dynamic shapes are returned unchanged, and each static dimension becomes
// a fixed range {len, len} with no optimals.
shape shape::to_dynamic() const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> converted;
        converted.reserve(sub_shapes().size());
        for(const auto& member : sub_shapes())
            converted.push_back(member.to_dynamic());
        return {converted};
    }
    if(this->dynamic())
        return *this;
    return {type(), lens(), lens(), {}};
}

524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
// Static counterpart of this shape. Tuples convert member by member, static
// shapes are returned unchanged, fixed dynamic dimensions keep their value
// and every non-fixed dimension becomes `x`.
shape shape::to_static(std::size_t x) const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> converted;
        converted.reserve(sub_shapes().size());
        for(const auto& member : sub_shapes())
            converted.push_back(member.to_static(x));
        return {converted};
    }
    if(not this->dynamic())
        return *this;
    std::vector<std::size_t> static_lens;
    static_lens.reserve(this->dyn_dims().size());
    for(const auto& dd : this->dyn_dims())
        static_lens.push_back(dd.is_fixed() ? dd.max : x);
    return {type(), static_lens};
}

Paul's avatar
Paul committed
548
std::size_t shape::element_space() const { return impl->element_space(); }
Paul's avatar
Paul committed
549

550
std::string shape::type_string() const { return name(this->type()); }
Paul's avatar
Paul committed
551

Charlie Lin's avatar
Charlie Lin committed
552
553
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }

Charlie Lin's avatar
Charlie Lin committed
554
555
556
557
558
559
560
561
562
563
564
bool shape::any_of_dynamic() const
{
    if(this->dynamic())
    {
        return true;
    }
    return std::any_of(this->sub_shapes().cbegin(), this->sub_shapes().cend(), [](auto s) {
        return s.any_of_dynamic();
    });
}

565
566
567
568
569
570
571
572
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
{
    if(not this->dynamic())
    {
        MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
    }
    return impl->m_dyn_dims;
}
Charlie Lin's avatar
Charlie Lin committed
573
574
575
576
577
578
579
580
581
582
583

std::vector<std::size_t> shape::min_lens() const
{
    return this->dynamic() ? impl->min_lens() : this->lens();
}

std::vector<std::size_t> shape::max_lens() const
{
    return this->dynamic() ? impl->max_lens() : this->lens();
}

584
std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }
Charlie Lin's avatar
Charlie Lin committed
585
586
587

bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }

588
bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }
Charlie Lin's avatar
Charlie Lin committed
589

Charlie Lin's avatar
Charlie Lin committed
590
591
592
593
// Shift the range [min, max] and every optimal value up by x.
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
    this->min += x;
    this->max += x;
    // Shifting by a constant preserves ordering, so elements arrive in
    // ascending order and the end() hint is always the right position.
    std::set<std::size_t> new_optimals;
    std::transform(this->optimals.begin(),
                   this->optimals.end(),
                   std::inserter(new_optimals, new_optimals.end()),
                   [&x](const auto& opt) { return (opt + x); });
    // Move the rebuilt set in instead of copying it.
    this->optimals = std::move(new_optimals);
    return *this;
}

// Shift the range [min, max] and every optimal value down by x. In debug
// builds, asserts that no value would underflow.
shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
{
    assert(this->min >= x);
    assert(this->max >= x);
    this->min -= x;
    this->max -= x;
    std::set<std::size_t> new_optimals;
    std::transform(this->optimals.begin(),
                   this->optimals.end(),
                   std::inserter(new_optimals, new_optimals.end()),
                   [&x](const auto& opt) {
                       assert(opt >= x);
                       return (opt - x);
                   });
    this->optimals = std::move(new_optimals);
    return *this;
}

Charlie Lin's avatar
Charlie Lin committed
621
622
// Two dynamic dimensions are equal when their ranges match and, unless both
// are fixed to a single value, their optimal sets match too.
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
    if(x.min != y.min or x.max != y.max)
        return false;
    if(x.is_fixed() and y.is_fixed())
        return true;
    return x.optimals == y.optimals;
}

bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
    return not(x == y);
}

// Prints as "[ min, max, {optimals} ]".
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
    os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals)
       << "} ]";
    return os;
}

638
639
640
641
642
643
644
645
// A dynamic dimension equals a plain length only when it is fixed to it.
bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
{
    return x.is_fixed() and x.min == y;
}
bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(y == x); }

Charlie Lin's avatar
Charlie Lin committed
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
// Copy of x with its range and optimals shifted up by y.
shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
{
    shape::dynamic_dimension shifted = x;
    shifted += y;
    return shifted;
}

// Addition is commutative for this type.
shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
{
    return y + x;
}

// Copy of x with its range and optimals shifted down by y.
shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
{
    shape::dynamic_dimension shifted = x;
    shifted -= y;
    return shifted;
}

Paul's avatar
Paul committed
663
664
bool operator==(const shape& x, const shape& y)
{
Charlie Lin's avatar
Charlie Lin committed
665
666
667
668
669
670
671
672
    if(x.dynamic() and y.dynamic())
    {
        return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
                                    x.sub_shapes() == y.sub_shapes());
    }
    return x.impl == y.impl or
           (x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
            x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
Paul's avatar
Paul committed
673
}
Charlie Lin's avatar
Charlie Lin committed
674

675
bool operator!=(const shape& x, const shape& y) { return not(x == y); }
Paul's avatar
Paul committed
676

Paul's avatar
Paul committed
677
678
std::ostream& operator<<(std::ostream& os, const shape& x)
{
Paul Fultz II's avatar
Paul Fultz II committed
679
680
    if(x.sub_shapes().empty())
    {
Charlie Lin's avatar
Charlie Lin committed
681
682
683
684
685
686
687
688
689
690
691
692
        if(x.dynamic())
        {
            os << "dynamic, ";
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.dyn_dims()) << "}";
        }
        else
        {
            os << x.type_string() << ", ";
            os << "{" << to_string_range(x.lens()) << "}, ";
            os << "{" << to_string_range(x.strides()) << "}";
        }
Paul Fultz II's avatar
Paul Fultz II committed
693
694
695
696
697
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
Paul's avatar
Paul committed
698
699
700
    return os;
}

701
702
// Parse a type tag from either its enumerator spelling (e.g. "float_type")
// or its C++ type spelling (e.g. "float"); "tuple_type" and "tuple" both map
// to tuple_type. Unknown names make std::unordered_map::at throw.
shape::type_t shape::parse_type(const std::string& s)
{
    static const std::unordered_map<std::string, shape::type_t> name_to_type = {
        {"tuple_type", tuple_type},
        {"tuple", tuple_type},
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP)};
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    return name_to_type.at(s);
}

Paul Fultz II's avatar
Paul Fultz II committed
711
712
// Member shapes of a tuple shape; empty for every non-tuple shape.
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }

713
714
715
// Serialize a shape. All five keys are always written; only the accessors
// valid for this kind of shape are called, because lens()/strides() throw on
// dynamic shapes and dyn_dims() throws on static ones.
void migraphx_to_value(value& v, const shape& s)
{
    value serialized;
    serialized["type"]       = migraphx::to_value(s.type_string());
    serialized["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    if(not s.dynamic())
    {
        serialized["lens"]               = migraphx::to_value(s.lens());
        serialized["strides"]            = migraphx::to_value(s.strides());
        serialized["dynamic_dimensions"] = {};
    }
    else
    {
        serialized["lens"]               = {};
        serialized["strides"]            = {};
        serialized["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
    }
    v = serialized;
}
Charlie Lin's avatar
Charlie Lin committed
733

734
735
void migraphx_from_value(const value& v, shape& s)
{
Paul Fultz II's avatar
Paul Fultz II committed
736
737
738
739
740
741
742
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
Charlie Lin's avatar
Charlie Lin committed
743
744
745
746
747
748
749
750
751
752
        if(v.at("dynamic_dimensions").empty())
        {
            s = shape{shape::parse_type(t),
                      v.at("lens").to_vector<std::size_t>(),
                      v.at("strides").to_vector<std::size_t>()};
        }
        else
        {
            auto v_dd = v.at("dynamic_dimensions");
            std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
753
754
755
756
            std::transform(
                v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
                    return from_value<shape::dynamic_dimension>(x);
                });
Charlie Lin's avatar
Charlie Lin committed
757
758
759

            s = shape{shape::parse_type(t), dyn_dims};
        }
Paul Fultz II's avatar
Paul Fultz II committed
760
    }
761
762
}

Paul's avatar
Paul committed
763
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
764
} // namespace migraphx