literal.hpp 3.77 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <type_traits>
#include <vector>

#include <migraph/shape.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/argument.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
#include <migraph/make_shared_array.hpp>
#include <migraph/config.hpp>

14
15
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
18
19
20
/**
 * @brief Represents a raw literal
 * @details This stores the literal has a raw buffer that is owned by this class
 */
Paul's avatar
Paul committed
21
struct literal : raw_data<literal>
Paul's avatar
Paul committed
22
{
Paul's avatar
Paul committed
23
    literal() {}
Paul's avatar
Paul committed
24

Paul's avatar
Paul committed
25
    template <class U, class T = deduce<U>>
Paul's avatar
Paul committed
26
    literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{})
Paul's avatar
Paul committed
27
    {
Paul's avatar
Paul committed
28
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
Paul's avatar
Paul committed
29
        *(reinterpret_cast<T*>(buffer.get())) = x;
Paul's avatar
Paul committed
30
31
    }

Paul's avatar
Paul committed
32
    template <class T>
Paul's avatar
Paul committed
33
    literal(const shape& s, const std::vector<T>& x)
34
        : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
35
    {
Paul's avatar
Paul committed
36
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
Paul's avatar
Paul committed
37
        fill(x.begin(), x.end());
Paul's avatar
Paul committed
38
39
    }

Paul's avatar
Paul committed
40
    template <class T>
Paul's avatar
Paul committed
41
    literal(const shape& s, const std::initializer_list<T>& x)
42
        : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
43
    {
Paul's avatar
Paul committed
44
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
Paul's avatar
Paul committed
45
        fill(x.begin(), x.end());
Paul's avatar
Paul committed
46
47
    }

Paul's avatar
Paul committed
48
    template <class Iterator>
Paul's avatar
Paul committed
49
    literal(const shape& s, Iterator start, Iterator end)
50
        : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
51
    {
Paul's avatar
Paul committed
52
        fill(start, end);
Paul's avatar
Paul committed
53
54
    }

55
    literal(const shape& s, const char* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
Paul's avatar
Paul committed
56
57
58
    {
        std::copy(x, x + s.bytes(), buffer.get());
    }
Paul's avatar
Paul committed
59

Paul's avatar
Paul committed
60
    /// Whether data is available
Paul's avatar
Paul committed
61
    bool empty() const { return this->buffer == nullptr; }
Paul's avatar
Paul committed
62

Paul's avatar
Paul committed
63
    /// Provides a raw pointer to the data
Paul's avatar
Paul committed
64
    const char* data() const { return this->buffer.get(); }
Paul's avatar
Paul committed
65

Paul's avatar
Paul committed
66
    const shape& get_shape() const { return this->m_shape; }
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
    /// Convert the data to an argument
Paul's avatar
Paul committed
69
70
    argument get_argument() const
    {
Paul's avatar
Paul committed
71
        std::vector<char> b(buffer.get(), buffer.get() + m_shape.bytes());
Paul's avatar
Paul committed
72
        return {m_shape, [b]() mutable { return b.data(); }};
Paul's avatar
Paul committed
73
74
    }

Paul's avatar
Paul committed
75
    private:
Paul's avatar
Paul committed
76
    std::shared_ptr<char> buffer;
Paul's avatar
Paul committed
77
    shape m_shape;
Paul's avatar
Paul committed
78
79
80
81

    template <class Iterator>
    void fill(Iterator start, Iterator end)
    {
Paul's avatar
Paul committed
82
        if(m_shape.standard())
Paul's avatar
Paul committed
83
        {
Paul's avatar
Paul committed
84
            m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
Paul's avatar
Paul committed
85
86
87
88
89
        }
        else
        {
            auto it = start;
            m_shape.visit_type([&](auto as) {
Paul's avatar
Paul committed
90
                auto output = make_view(m_shape, as.from(buffer.get()));
Paul's avatar
Paul committed
91
92
                shape_for_each(output.get_shape(), [&](const auto& idx) {
                    output(idx.begin(), idx.end()) = *it;
Paul's avatar
Paul committed
93
                    it++;
Paul's avatar
Paul committed
94
95
96
97
                });
            });
        }
    }
Paul's avatar
Paul committed
98
99
};

Paul's avatar
Paul committed
100
template <class F>
Paul's avatar
Paul committed
101
102
103
104
105
literal transform(literal l, F f)
{
    literal result;
    l.visit([&](auto x) {
        using type = std::remove_cv_t<typename decltype(x)::value_type>;
Paul's avatar
Paul committed
106
        std::vector<type> output(x.size(), type(0));
Paul's avatar
Paul committed
107
108
109
110
111
112
        std::transform(x.begin(), x.end(), output.begin(), f);
        result = literal{l.get_shape(), output};
    });
    return result;
}

Paul's avatar
Paul committed
113
114
115
116
117
118
119
/**
 * @brief Combine two literals element-wise with a binary function.
 * @param l1 First operand; its shape and element type determine the result.
 * @param l2 Second operand; must have the same shape as @p l1 (asserted).
 * @param f Callable invoked as f(a, b) for corresponding elements.
 * @return A new literal with the shape of @p l1 holding the combined values.
 */
template <class F>
literal transform(literal l1, literal l2, F f)
{
    assert(l1.get_shape() == l2.get_shape());
    literal combined;
    auto apply = [&](auto lhs, auto rhs) {
        using elem_t = std::remove_cv_t<typename decltype(lhs)::value_type>;
        // Pre-size the output so std::transform can write into it directly.
        std::vector<elem_t> values(lhs.size(), elem_t(0));
        std::transform(lhs.begin(), lhs.end(), rhs.begin(), values.begin(), f);
        combined = literal{l1.get_shape(), values};
    };
    visit_all(l1, l2)(apply);
    return combined;
}

127
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
128
} // namespace migraph
Paul's avatar
Paul committed
129
130

#endif