/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#ifndef MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP
#define MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

36
37
38
39
/**
 * Static allocate:
 * No inputs: `allocate()`
 * `this.s` attribute set to the static output shape of the buffer.
40
41
 * `this.s` attribute can be set to a dynamic output shape; however this will allocate the maximum
 * buffer size for that case
42
43
44
45
 *
 * Dynamic allocate:
 * One input: `allocate(output_dims)`
 * `output_dims` are the output buffer dimensions and has a static shape.
46
47
48
49
 * Either `this.s` or `this.buf_type` (but not both) must be set to calculate the dynamic output
 * shape at compute time. If `this.buf_type` is set, the compute_shape() of allocate at compile time
 * will have dynamic_dimensions from {0, max_int} with rank = output_dims.ndim(). If `this.s` is set
 * then the compute_shape() will output `this.s`; `this.s` should be a dynamic shape.
50
 */
51
52
struct allocate
{
53
    optional<shape> s;
Charlie Lin's avatar
Charlie Lin committed
54
    // for dynamic allocate to set the buffer type
55
    optional<shape::type_t> buf_type;
Charlie Lin's avatar
Charlie Lin committed
56

57
58
59
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Charlie Lin's avatar
Charlie Lin committed
60
        return pack(f(self.s, "shape"), f(self.buf_type, "buf_type"));
61
    }
Charlie Lin's avatar
Charlie Lin committed
62

63
    std::string name() const { return "allocate"; }
Charlie Lin's avatar
Charlie Lin committed
64

65
66
    shape compute_shape(const std::vector<shape>& inputs) const
    {
67
        if(s.has_value())
Charlie Lin's avatar
Charlie Lin committed
68
        {
69
70
71
72
            if(buf_type.has_value())
            {
                MIGRAPHX_THROW("ALLOCATE: shape and buf_type attributes both set");
            }
73
74
75
76
77
78
79
80
            if(inputs.size() == 1)
            {
                migraphx::check_shapes{inputs, *this, false}.only_dims(1);
            }
            else
            {
                migraphx::check_shapes{inputs, *this, false}.has(0);
            }
81
            return s.value();
Charlie Lin's avatar
Charlie Lin committed
82
83
84
        }
        else
        {
85
86
87
88
            if(not buf_type.has_value())
            {
                MIGRAPHX_THROW("ALLOCATE: shape and buf_type attributes both not set");
            }
89
            migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1);
Charlie Lin's avatar
Charlie Lin committed
90
91
92
93
            const auto& out_dims = inputs.at(0);
            std::size_t max_val = std::numeric_limits<std::size_t>::max();
            std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
                                                           shape::dynamic_dimension{0, max_val});
94
            return {buf_type.value(), dyn_dims};
Charlie Lin's avatar
Charlie Lin committed
95
        }
96
    }
Charlie Lin's avatar
Charlie Lin committed
97
    argument compute(const shape& output_shape, const std::vector<argument>& args) const
98
    {
Charlie Lin's avatar
Charlie Lin committed
99
100
        if(args.empty())
        {
101
            return argument{output_shape};
Charlie Lin's avatar
Charlie Lin committed
102
103
104
105
106
        }
        else
        {
            std::vector<std::size_t> output_dims(output_shape.ndim());
            args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); });
107
108
109
110
111
            if(s)
            {
                return argument{shape{s->type(), output_dims}};
            }
            return argument{shape{buf_type.value(), output_dims}};
Charlie Lin's avatar
Charlie Lin committed
112
        }
113
114
115
116
117
118
119
120
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif