shape.cpp 12.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24

Paul's avatar
Paul committed
25
26
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
27
#include <migraphx/serialize.hpp>
28
#include <migraphx/permutation.hpp>
Paul's avatar
Paul committed
29
30
31
#include <numeric>
#include <algorithm>
#include <functional>
32
#include <unordered_map>
Paul's avatar
Paul committed
33
#include <iostream>
Paul's avatar
Paul committed
34

Paul's avatar
Paul committed
35
namespace migraphx {
Paul's avatar
Paul committed
36
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
37

Paul's avatar
Paul committed
38
39
struct shape_impl
{
Paul's avatar
Paul committed
40
41
    static std::shared_ptr<shape_impl> default_shape()
    {
42
        static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
Paul's avatar
Paul committed
43
44
45
        return result;
    }

Paul Fultz II's avatar
Paul Fultz II committed
46
    shape_impl() : m_type(shape::float_type) {}
Paul's avatar
Paul committed
47

Paul Fultz II's avatar
Paul Fultz II committed
48
49
50
51
    shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
    {
        assert(t != shape::tuple_type);
    }
Paul's avatar
Paul committed
52
53
54
    shape_impl(shape::type_t t, std::vector<std::size_t> l)
        : m_type(t), m_lens(std::move(l)), m_standard(true)
    {
Paul Fultz II's avatar
Paul Fultz II committed
55
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
56
57
58
59
60
61
        this->calculate_strides();
        assert(m_lens.size() == m_strides.size());
    }
    shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
        : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
    {
Paul Fultz II's avatar
Paul Fultz II committed
62
        assert(t != shape::tuple_type);
Paul's avatar
Paul committed
63
        assert(m_lens.size() == m_strides.size());
64
        m_standard = this->elements() == this->element_space() and not skips() and
65
                     std::is_sorted(m_strides.rbegin(), m_strides.rend());
Paul's avatar
Paul committed
66
    }
Paul Fultz II's avatar
Paul Fultz II committed
67
68

    shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
Paul's avatar
Paul committed
69
    shape::type_t m_type;
Paul Fultz II's avatar
Paul Fultz II committed
70
71
72
73
    std::vector<std::size_t> m_lens    = {};
    std::vector<std::size_t> m_strides = {};
    std::vector<shape> m_shapes        = {};
    bool m_standard                    = false;
Paul's avatar
Paul committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

    void calculate_strides()
    {
        m_strides.clear();
        m_strides.resize(m_lens.size(), 0);
        if(m_strides.empty())
            return;
        m_strides.back() = 1;
        std::partial_sum(m_lens.rbegin(),
                         m_lens.rend() - 1,
                         m_strides.rbegin() + 1,
                         std::multiplies<std::size_t>());
    }

    std::size_t element_space() const
    {
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::inner_product(m_lens.begin(),
                                  m_lens.end(),
                                  m_strides.begin(),
                                  std::size_t{0},
                                  std::plus<std::size_t>{},
                                  [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
               1;
    }

    std::size_t elements() const
    {
        assert(m_lens.size() == m_strides.size());
        if(m_lens.empty())
            return 0;
        return std::accumulate(
            m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }
110

111
112
113
114
115
116
117
118
119
    // Does the shape skip over elements?
    bool skips() const
    {
        assert(m_lens.size() == m_strides.size());
        if(elements() == 1)
            return false;
        return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
    }

120
    std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
Paul's avatar
Paul committed
121
};
Paul's avatar
Paul committed
122

123
124
125
126
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
Paul Fultz II's avatar
Paul Fultz II committed
127
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
128
129
130
    return result;
}

131
132
133
134
std::string shape::name(shape::type_t t)
{
    // Each generated case returns the enumerator's own spelling (#x).
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
    case x: return #x;
    switch(t)
    {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
    case tuple_type: return "tuple_type";
    }
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    // Reached only for a value outside the enumeration.
    MIGRAPHX_THROW("Invalid type");
}
std::string shape::cpp_type(shape::type_t t)
{
    // Each generated case returns the C++ type spelling (#t) paired with the
    // enumerator in MIGRAPHX_SHAPE_VISIT_TYPES.
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
    case x: return #t;
    switch(t)
    {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
    // A tuple has no single element type.
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
    }
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    // Reached only for a value outside the enumeration.
    MIGRAPHX_THROW("Invalid type");
}

Paul's avatar
Paul committed
156
157
158
shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Paul's avatar
Paul committed
159
shape::shape(type_t t, std::vector<std::size_t> l)
Paul's avatar
Paul committed
160
    : impl(std::make_shared<shape_impl>(t, std::move(l)))
Paul's avatar
Paul committed
161
162
163
{
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
Paul's avatar
Paul committed
164
    : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
Paul's avatar
Paul committed
165
166
167
{
}

Paul Fultz II's avatar
Paul Fultz II committed
168
169
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}

170
171
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}

172
173
174
175
176
177
178
179
180
181
shape shape::from_permutation(type_t t,
                              const std::vector<std::size_t>& l,
                              const std::vector<int64_t>& perm)
{
    // Build a standard shape over the permuted lens, then apply the inverse
    // permutation so the resulting lens match `l` in their original order.
    const auto permuted_lens = reorder_dims(l, perm);
    const auto inverse       = invert_permutation(perm);
    shape out                = reorder_shape({t, permuted_lens}, inverse);
    assert(out.lens() == l);
    return out;
}

Paul's avatar
Paul committed
182
183
184
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
Paul's avatar
Paul committed
185
std::size_t shape::elements() const { return impl->elements(); }
Paul's avatar
Paul committed
186
187
std::size_t shape::bytes() const
{
Paul Fultz II's avatar
Paul Fultz II committed
188
189
190
191
192
193
194
195
196
197
198
199
200
    if(this->sub_shapes().empty())
    {
        std::size_t n = 0;
        this->visit_type([&](auto as) { n = as.size(); });
        return n * this->element_space();
    }
    else
    {
        return std::accumulate(this->sub_shapes().begin(),
                               this->sub_shapes().end(),
                               std::size_t{0},
                               [&](auto x, auto y) { return x + y.bytes(); });
    }
Paul's avatar
Paul committed
201
}
Scott Thornton's avatar
Scott Thornton committed
202
203
204
std::size_t shape::type_size() const
{
    std::size_t n = 0;
Paul Fultz II's avatar
Paul Fultz II committed
205
206
    if(this->sub_shapes().empty())
        this->visit_type([&](auto as) { n = as.size(); });
Scott Thornton's avatar
Scott Thornton committed
207
208
    return n;
}
Paul's avatar
Paul committed
209
210
211
212
213
214
215
216
217
218
219
220
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());
    return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
}
Paul's avatar
Paul committed
221
222
223
std::size_t shape::index(std::size_t i) const
{
    assert(this->lens().size() == this->strides().size());
    // In a standard layout the linear index already is the memory offset.
    if(this->standard())
        return i;

    // Otherwise peel off one coordinate per axis, innermost axis first, and
    // accumulate coordinate * stride.
    const auto& dims   = this->lens();
    const auto& steps  = this->strides();
    std::size_t offset = 0;
    std::size_t block  = 1; // elements covered by the axes already visited
    for(std::size_t k = dims.size(); k > 0; --k)
    {
        const std::size_t axis  = k - 1;
        const std::size_t coord = (i % (block * dims[axis])) / block;
        offset += steps[axis] * coord;
        block *= dims[axis];
    }
    return offset;
}
242
243
244
245
246
247

std::vector<std::size_t> shape::multi(std::size_t i) const
{
    assert(this->standard());
    // One coordinate per dimension, filled in by multi_copy.
    const std::size_t rank = this->lens().size();
    std::vector<std::size_t> coords(rank);
    this->multi_copy(i, coords.data(), coords.data() + rank);
    return coords;
}

void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
{
    assert(this->standard());
    (void)end;
    assert(lens().size() <= (end - start));
258
    std::transform(strides().begin(),
Shucai Xiao's avatar
Shucai Xiao committed
259
260
                   strides().end(),
                   lens().begin(),
261
                   start,
Shucai Xiao's avatar
Shucai Xiao committed
262
263
264
265
                   [&](std::size_t stride, std::size_t len) {
                       assert(len > 0 and stride > 0);
                       return (i / stride) % len;
                   });
266
267
}

Paul Fultz II's avatar
Paul Fultz II committed
268
269
bool shape::packed() const
{
270
271
    return this->sub_shapes().empty() and not impl->skips() and
           this->elements() == this->element_space();
Paul Fultz II's avatar
Paul Fultz II committed
272
}
Paul's avatar
Paul committed
273

Paul's avatar
Paul committed
274
275
bool shape::transposed() const
{
    assert(this->lens().size() == this->strides().size());
    // Zero strides (broadcast axes) say nothing about memory order, so they are
    // ignored. The shape is transposed when some non-zero stride is larger than
    // the non-zero stride before it, i.e. the remaining strides are not in
    // non-increasing order. Scanning in place needs no temporary buffer.
    bool have_previous   = false;
    std::size_t previous = 0;
    for(const std::size_t stride : this->strides())
    {
        if(stride == 0)
            continue;
        if(have_previous and stride > previous)
            return true;
        previous      = stride;
        have_previous = true;
    }
    return false;
}
Paul's avatar
Paul committed
292
293
294
295

bool shape::broadcasted() const
{
    assert(this->lens().size() == this->strides().size());
    // Any zero stride means the same element is reused along that axis.
    const auto& s = this->strides();
    return std::find(s.begin(), s.end(), std::size_t{0}) != s.end();
}

Khalique's avatar
Khalique committed
300
301
302
303
bool shape::scalar() const
{
    assert(this->lens().size() == this->strides().size());
    // Scalar: not a tuple, and every stride is zero, so all indices alias one
    // element. all_of states this directly and short-circuits. Summing the
    // strides and comparing with 0 could be fooled by unsigned wraparound.
    const auto& s = this->strides();
    return this->sub_shapes().empty() and
           std::all_of(s.begin(), s.end(), [](std::size_t x) { return x == 0; });
}

Paul's avatar
Paul committed
308
bool shape::standard() const { return impl->m_standard; }
Paul's avatar
Paul committed
309

310
311
312
313
314
315
316
317
shape shape::normalize_standard() const
{
    // Non-standard layouts are returned unchanged.
    if(not this->standard())
        return *this;
    // Rebuild from type and lens so the strides are recomputed as packed strides.
    return {this->type(), this->lens()};
}

318
319
320
321
322
323
324
325
326
327
328
329
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
    assert(l.size() == this->lens().size());
    // Keep this shape's axis ordering (found via find_permutation) for the new lens.
    return shape::from_permutation(t, l, find_permutation(*this));
}

shape shape::with_lens(const std::vector<std::size_t>& l) const
{
    // Same as above, keeping the current element type.
    const type_t current = this->type();
    return this->with_lens(current, l);
}

330
331
332
333
334
335
336
shape shape::with_type(type_t t) const
{
    // Clone the implementation so this shape is not modified, then retype the clone.
    auto cloned    = impl->copy();
    cloned->m_type = t;
    return shape{std::move(cloned)};
}

Paul's avatar
Paul committed
337
std::size_t shape::element_space() const { return impl->element_space(); }
Paul's avatar
Paul committed
338

339
std::string shape::type_string() const { return name(this->type()); }
Paul's avatar
Paul committed
340

Paul's avatar
Paul committed
341
342
bool operator==(const shape& x, const shape& y)
{
Paul Fultz II's avatar
Paul Fultz II committed
343
344
    return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
                                x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
Paul's avatar
Paul committed
345
}
Paul's avatar
Paul committed
346
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
Paul's avatar
Paul committed
347

Paul's avatar
Paul committed
348
349
std::ostream& operator<<(std::ostream& os, const shape& x)
{
Paul Fultz II's avatar
Paul Fultz II committed
350
351
352
353
354
355
356
357
358
359
    if(x.sub_shapes().empty())
    {
        os << x.type_string() << ", ";
        os << "{" << to_string_range(x.lens()) << "}, ";
        os << "{" << to_string_range(x.strides()) << "}";
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
Paul's avatar
Paul committed
360
361
362
    return os;
}

363
364
shape::type_t shape::parse_type(const std::string& s)
{
    // Accepted spellings: for each MIGRAPHX_SHAPE_VISIT_TYPES entry, both the
    // enumerator name (#x) and the C++ type name (#t); plus "tuple_type" and "tuple".
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
    static const std::unordered_map<std::string, shape::type_t> m = {
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
                                                                            tuple_type},
        {"tuple", tuple_type}};
    // Keep the generator macro local, like the other generators in this file.
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    // An unknown name makes at() throw std::out_of_range.
    return m.at(s);
}

Paul Fultz II's avatar
Paul Fultz II committed
373
374
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }

375
376
377
void migraphx_to_value(value& v, const shape& s)
{
    // Serialize every field; sub_shapes is written even when it is empty.
    value serialized;
    serialized["type"]       = migraphx::to_value(s.type_string());
    serialized["lens"]       = migraphx::to_value(s.lens());
    serialized["strides"]    = migraphx::to_value(s.strides());
    serialized["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    v                        = serialized;
}
void migraphx_from_value(const value& v, shape& s)
{
Paul Fultz II's avatar
Paul Fultz II committed
386
387
388
389
390
391
392
393
394
395
396
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
        s = shape{shape::parse_type(t),
                  v.at("lens").to_vector<std::size_t>(),
                  v.at("strides").to_vector<std::size_t>()};
    }
397
398
}

Paul's avatar
Paul committed
399
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
400
} // namespace migraphx