tensor_view.hpp 4.25 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
6
#include <migraph/shape.hpp>
#include <migraph/float_equal.hpp>
#include <migraph/requires.hpp>
7
#include <migraph/config.hpp>
Paul's avatar
Paul committed
8
9

#include <iostream>
Paul's avatar
Paul committed
10
#include <utility>
Paul's avatar
Paul committed
11

12
13
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
template <class T>
Paul's avatar
Paul committed
16
17
struct tensor_view
{
18
    using value_type = T;
Paul's avatar
Paul committed
19
    tensor_view() : m_data(nullptr) {}
Paul's avatar
Paul committed
20
    tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
Paul's avatar
Paul committed
21

Paul's avatar
Paul committed
22
    const shape& get_shape() const { return this->m_shape; }
Paul's avatar
Paul committed
23

Paul's avatar
Paul committed
24
    bool empty() const { return m_data == nullptr || m_shape.lens().empty(); }
Paul's avatar
Paul committed
25

Paul's avatar
Paul committed
26
    std::size_t size() const { return m_shape.elements(); }
Paul's avatar
Paul committed
27

Paul's avatar
Paul committed
28
    T* data() { return this->m_data; }
Paul's avatar
Paul committed
29

Paul's avatar
Paul committed
30
    const T* data() const { return this->m_data; }
Paul's avatar
Paul committed
31

Paul's avatar
Paul committed
32
    template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
Paul's avatar
Paul committed
33
34
    const T& operator()(Ts... xs) const
    {
Paul's avatar
Paul committed
35
        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
Scott Thornton's avatar
Scott Thornton committed
36
        assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
Paul's avatar
Paul committed
37
        return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
Paul's avatar
Paul committed
38
39
    }

Paul's avatar
Paul committed
40
    template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
Paul's avatar
Paul committed
41
42
    T& operator()(Ts... xs)
    {
Paul's avatar
Paul committed
43
        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
Scott Thornton's avatar
Scott Thornton committed
44
        assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
Paul's avatar
Paul committed
45
        return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
Paul's avatar
Paul committed
46
47
    }

Paul's avatar
Paul committed
48
    template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
Paul's avatar
Paul committed
49
50
51
52
53
    const T& operator()(Iterator start, Iterator last) const
    {
        return m_data[m_shape.index(start, last)];
    }

Paul's avatar
Paul committed
54
    template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
Paul's avatar
Paul committed
55
56
57
58
59
    T& operator()(Iterator start, Iterator last)
    {
        return m_data[m_shape.index(start, last)];
    }

Paul's avatar
Paul committed
60
61
62
    T& operator[](std::size_t i)
    {
        assert(!this->empty() && i < this->size());
Paul's avatar
Paul committed
63
        return m_data[m_shape.index(i)];
Paul's avatar
Paul committed
64
65
66
67
68
    }

    const T& operator[](std::size_t i) const
    {
        assert(!this->empty() && i < this->size());
Paul's avatar
Paul committed
69
        return m_data[m_shape.index(i)];
Paul's avatar
Paul committed
70
71
72
73
74
    }

    T& front()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
75
        return m_data[0];
Paul's avatar
Paul committed
76
77
78
79
80
    }

    const T& front() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
81
        return m_data[0];
Paul's avatar
Paul committed
82
83
84
85
86
    }

    T& back()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
87
        return m_data[m_shape.index(this->size() - 1)];
Paul's avatar
Paul committed
88
89
90
91
92
    }

    const T& back() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
93
        return m_data[m_shape.index(this->size() - 1)];
Paul's avatar
Paul committed
94
95
    }

Paul's avatar
Paul committed
96
    // TODO: Add iterators so it can handle nonstandard tensors
Paul's avatar
Paul committed
97
98
    T* begin()
    {
Paul's avatar
Paul committed
99
        assert(this->m_shape.standard() or this->empty());
Paul's avatar
Paul committed
100
        return m_data;
Paul's avatar
Paul committed
101
102
103
104
    }

    T* end()
    {
Paul's avatar
Paul committed
105
        assert(this->m_shape.standard() or this->empty());
Paul's avatar
Paul committed
106
        if(this->empty())
Paul's avatar
Paul committed
107
            return m_data;
Paul's avatar
Paul committed
108
        else
Paul's avatar
Paul committed
109
            return m_data + this->size();
Paul's avatar
Paul committed
110
111
112
113
    }

    const T* begin() const
    {
Paul's avatar
Paul committed
114
        assert(this->m_shape.standard() or this->empty());
Paul's avatar
Paul committed
115
        return m_data;
Paul's avatar
Paul committed
116
117
118
119
    }

    const T* end() const
    {
Paul's avatar
Paul committed
120
        assert(this->m_shape.standard() or this->empty());
Paul's avatar
Paul committed
121
        if(this->empty())
Paul's avatar
Paul committed
122
            return m_data;
Paul's avatar
Paul committed
123
        else
Paul's avatar
Paul committed
124
            return m_data + this->size();
Paul's avatar
Paul committed
125
126
    }

Paul's avatar
Paul committed
127
128
129
130
131
    friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
    {
        if(!x.empty())
        {
            os << x.front();
Paul's avatar
Paul committed
132
            for(std::size_t i = 1; i < x.m_shape.elements(); i++)
Paul's avatar
Paul committed
133
            {
Paul's avatar
Paul committed
134
                os << ", " << x.m_data[x.m_shape.index(i)];
Paul's avatar
Paul committed
135
136
137
138
139
            }
        }
        return os;
    }

Paul's avatar
Paul committed
140
    private:
Paul's avatar
Paul committed
141
142
    T* m_data;
    shape m_shape;
Paul's avatar
Paul committed
143
144
};

Paul's avatar
Paul committed
145
template <class T, class U>
146
147
148
149
150
151
152
153
154
155
156
157
158
159
bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
    if(x.get_shape() == y.get_shape())
    {
        for(std::size_t i = 0; i < x.get_shape().elements(); i++)
        {
            if(!float_equal(x[i], y[i]))
                return false;
        }
        return true;
    }
    return false;
}

Paul's avatar
Paul committed
160
161
162
163
164
// Negation of operator== for tensor views.
template <class T, class U>
bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
{
    const bool same = (x == y);
    return !same;
}
165

Paul's avatar
Paul committed
166
template <class T>
Paul's avatar
Paul committed
167
168
169
170
171
tensor_view<T> make_view(shape s, T* data)
{
    return {s, data};
}

172
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
173
} // namespace migraph
Paul's avatar
Paul committed
174
175

#endif