literal.hpp 5.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24
25
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_LITERAL_HPP
Paul's avatar
Paul committed
26

Paul's avatar
Paul committed
27
28
29
30
31
32
33
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/raw_data.hpp>
#include <migraphx/make_shared_array.hpp>
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
34

Paul's avatar
Paul committed
35
36
#include <memory>

Paul's avatar
Paul committed
37
namespace migraphx {
Paul's avatar
Paul committed
38
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
39

Paul's avatar
Paul committed
40
41
42
43
/**
 * @brief Represents a raw literal
 * @details This stores the literal has a raw buffer that is owned by this class
 */
Paul's avatar
Paul committed
44
struct literal : raw_data<literal>
Paul's avatar
Paul committed
45
{
Paul's avatar
Paul committed
46
    literal() {}
Paul's avatar
Paul committed
47

Paul's avatar
Paul committed
48
49
    template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
    literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
Paul's avatar
Paul committed
50
    {
Paul's avatar
Paul committed
51
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
Paul's avatar
Paul committed
52
        *(reinterpret_cast<T*>(buffer.get())) = x;
Paul's avatar
Paul committed
53
54
    }

Paul's avatar
Paul committed
55
    template <class T>
Paul's avatar
Paul committed
56
    literal(const shape& s, const std::vector<T>& x)
57
        : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
58
    {
Paul's avatar
Paul committed
59
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
Paul's avatar
Paul committed
60
        fill(x.begin(), x.end());
Paul's avatar
Paul committed
61
62
    }

Paul's avatar
Paul committed
63
    template <class T>
Paul's avatar
Paul committed
64
    literal(const shape& s, const std::initializer_list<T>& x)
65
        : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
66
    {
Paul's avatar
Paul committed
67
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
Paul's avatar
Paul committed
68
        fill(x.begin(), x.end());
Paul's avatar
Paul committed
69
70
    }

Paul's avatar
Paul committed
71
    template <class Iterator>
Paul's avatar
Paul committed
72
    literal(const shape& s, Iterator start, Iterator end)
73
        : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
74
    {
Paul's avatar
Paul committed
75
        fill(start, end);
Paul's avatar
Paul committed
76
77
    }

78
79
    template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
    literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
80
81
82
    {
        std::copy(x, x + s.bytes(), buffer.get());
    }
Paul's avatar
Paul committed
83

Paul's avatar
Paul committed
84
    /// Whether data is available
Paul's avatar
Paul committed
85
    bool empty() const { return this->buffer == nullptr; }
Paul's avatar
Paul committed
86

Paul's avatar
Paul committed
87
    /// Provides a raw pointer to the data
Paul's avatar
Paul committed
88
    const char* data() const { return this->buffer.get(); }
Paul's avatar
Paul committed
89

Paul's avatar
Paul committed
90
    const shape& get_shape() const { return this->m_shape; }
Paul's avatar
Paul committed
91

Paul Fultz II's avatar
Paul Fultz II committed
92
93
    std::vector<literal> get_sub_objects() const { return {}; }

Paul's avatar
Paul committed
94
    /// Convert the data to an argument
Paul's avatar
Paul committed
95
96
    argument get_argument() const
    {
97
98
        auto b = make_shared_array<char>(buffer.get(), buffer.get() + m_shape.bytes());
        return {m_shape, [b]() { return b.get(); }};
Paul's avatar
Paul committed
99
100
    }

Paul's avatar
Paul committed
101
    private:
Paul's avatar
Paul committed
102
    std::shared_ptr<char> buffer;
Paul's avatar
Paul committed
103
    shape m_shape;
Paul's avatar
Paul committed
104
105
106
107

    template <class Iterator>
    void fill(Iterator start, Iterator end)
    {
Paul's avatar
Paul committed
108
        assert(std::distance(start, end) == m_shape.elements());
Paul's avatar
Paul committed
109
        if(m_shape.standard())
Paul's avatar
Paul committed
110
        {
Paul's avatar
Paul committed
111
            m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
Paul's avatar
Paul committed
112
113
114
115
116
        }
        else
        {
            auto it = start;
            m_shape.visit_type([&](auto as) {
Paul's avatar
Paul committed
117
                auto output = make_view(m_shape, as.from(buffer.get()));
Paul's avatar
Paul committed
118
                shape_for_each(output.get_shape(), [&](const auto& idx) {
119
                    output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
Paul's avatar
Paul committed
120
                    it++;
Paul's avatar
Paul committed
121
122
123
124
                });
            });
        }
    }
Paul's avatar
Paul committed
125
126
};

Paul's avatar
Paul committed
127
template <class F>
Paul's avatar
Paul committed
128
129
130
131
132
literal transform(literal l, F f)
{
    literal result;
    l.visit([&](auto x) {
        using type = std::remove_cv_t<typename decltype(x)::value_type>;
Paul's avatar
Paul committed
133
        std::vector<type> output(x.size(), type(0));
Paul's avatar
Paul committed
134
135
136
137
138
139
        std::transform(x.begin(), x.end(), output.begin(), f);
        result = literal{l.get_shape(), output};
    });
    return result;
}

Paul's avatar
Paul committed
140
141
142
143
144
145
146
/**
 * Apply a binary callable element-wise to two literals of the same shape.
 *
 * @param l1 Left operand; its shape is reused for the result.
 * @param l2 Right operand; must have a shape equal to @p l1 (asserted in debug).
 * @param f  Callable invoked as f(a, b) for paired elements; its result is
 *           converted to the shared element type.
 * @return A new literal with l1's shape holding f(a, b) for each element pair.
 */
template <class F>
literal transform(literal l1, literal l2, F f)
{
    assert(l1.get_shape() == l2.get_shape());
    const auto& result_shape = l1.get_shape();
    literal combined;
    visit_all(l1, l2)([&](auto lhs, auto rhs) {
        using elem_t = std::remove_cv_t<typename decltype(lhs)::value_type>;
        std::vector<elem_t> buf(lhs.size(), elem_t(0));
        std::transform(lhs.begin(), lhs.end(), rhs.begin(), buf.begin(), f);
        combined = literal{result_shape, buf};
    });
    return combined;
}

154
155
156
// Conversion hooks between a literal and a migraphx `value`, declared here and
// defined out of line. Presumably used for (de)serialization — TODO confirm.
void migraphx_to_value(value& v, const literal& l);
void migraphx_from_value(const value& v, literal& l);

Paul's avatar
Paul committed
157
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
158
} // namespace migraphx
Paul's avatar
Paul committed
159
160

#endif