"src/vscode:/vscode.git/clone" did not exist on "21b14ff241bffcf9cb8b4678631bec09706e1aad"
tensor_view.hpp 4.83 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_TENSOR_VIEW_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
6
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/config.hpp>

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
namespace migraphx {
Paul's avatar
Paul committed
14
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
17
18
19
20
21
22
23
/// Converts a tensor element into something suitable for `operator<<`.
/// The generic overload is the identity: the value is returned unchanged.
template <class T>
T as_number(T value)
{
    return value;
}

/// int8_t / uint8_t are character types on common platforms, and std::ostream
/// prints character types as characters. Widening them to 32-bit integers makes
/// them print as numbers instead.
inline int32_t as_number(int8_t value) { return int32_t{value}; }
inline uint32_t as_number(uint8_t value) { return uint32_t{value}; }

24
25
26
27
28
29
30
31
32
33
34
/// Element-access functor used as the "read" policy of basic_iota_iterator:
/// invoking it with a position n forwards to (*view)[n] and returns that
/// reference. T is the viewed type and may be const-qualified, in which case
/// the returned reference is const. `view` must be non-null when invoked.
template <class T>
struct tensor_view_iterator_read
{
    T* view;

    auto& operator()(std::size_t n) const
    {
        assert(nullptr != view);
        T& target = *view;
        return target[n];
    }
};

Paul's avatar
Paul committed
35
template <class T>
Paul's avatar
Paul committed
36
37
struct tensor_view
{
38
39
40
41
    using value_type = T;
    using iterator   = basic_iota_iterator<tensor_view_iterator_read<tensor_view<T>>, std::size_t>;
    using const_iterator =
        basic_iota_iterator<tensor_view_iterator_read<const tensor_view<T>>, std::size_t>;
Paul's avatar
Paul committed
42
    tensor_view() : m_data(nullptr) {}
Paul's avatar
Paul committed
43
    tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
Paul's avatar
Paul committed
44

Paul's avatar
Paul committed
45
    const shape& get_shape() const { return this->m_shape; }
Paul's avatar
Paul committed
46

Paul's avatar
Paul committed
47
    bool empty() const { return m_data == nullptr || m_shape.lens().empty(); }
Paul's avatar
Paul committed
48

Paul's avatar
Paul committed
49
    std::size_t size() const { return m_shape.elements(); }
Paul's avatar
Paul committed
50

Paul's avatar
Paul committed
51
    T* data() { return this->m_data; }
Paul's avatar
Paul committed
52

Paul's avatar
Paul committed
53
    const T* data() const { return this->m_data; }
Paul's avatar
Paul committed
54

Paul's avatar
Paul committed
55
    template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
Paul's avatar
Paul committed
56
57
    const T& operator()(Ts... xs) const
    {
Paul's avatar
Paul committed
58
        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
Scott Thornton's avatar
Scott Thornton committed
59
        assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
Paul's avatar
Paul committed
60
        return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
Paul's avatar
Paul committed
61
62
    }

Paul's avatar
Paul committed
63
    template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
Paul's avatar
Paul committed
64
65
    T& operator()(Ts... xs)
    {
Paul's avatar
Paul committed
66
        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
Scott Thornton's avatar
Scott Thornton committed
67
        assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
Paul's avatar
Paul committed
68
        return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
Paul's avatar
Paul committed
69
70
    }

Paul's avatar
Paul committed
71
    template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
Paul's avatar
Paul committed
72
73
    const T& operator()(Iterator start, Iterator last) const
    {
Paul Fultz II's avatar
Paul Fultz II committed
74
75
        assert(std::distance(start, last) > 0);
        assert(std::all_of(start, last, [](auto x) { return x >= 0; }));
Paul's avatar
Paul committed
76
77
78
        return m_data[m_shape.index(start, last)];
    }

Paul's avatar
Paul committed
79
    template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
Paul's avatar
Paul committed
80
81
    T& operator()(Iterator start, Iterator last)
    {
Paul Fultz II's avatar
Paul Fultz II committed
82
83
        assert(std::distance(start, last) > 0);
        assert(std::all_of(start, last, [](auto x) { return x >= 0; }));
Paul's avatar
Paul committed
84
85
86
        return m_data[m_shape.index(start, last)];
    }

Paul's avatar
Paul committed
87
88
89
    T& operator[](std::size_t i)
    {
        assert(!this->empty() && i < this->size());
Paul's avatar
Paul committed
90
        return m_data[m_shape.index(i)];
Paul's avatar
Paul committed
91
92
93
94
95
    }

    const T& operator[](std::size_t i) const
    {
        assert(!this->empty() && i < this->size());
Paul's avatar
Paul committed
96
        return m_data[m_shape.index(i)];
Paul's avatar
Paul committed
97
98
99
100
101
    }

    T& front()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
102
        return m_data[0];
Paul's avatar
Paul committed
103
104
105
106
107
    }

    const T& front() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
108
        return m_data[0];
Paul's avatar
Paul committed
109
110
111
112
113
    }

    T& back()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
114
        return m_data[m_shape.index(this->size() - 1)];
Paul's avatar
Paul committed
115
116
117
118
119
    }

    const T& back() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
120
        return m_data[m_shape.index(this->size() - 1)];
Paul's avatar
Paul committed
121
122
    }

123
    iterator begin() { return {0, {this}}; }
Paul's avatar
Paul committed
124

125
    iterator end() { return {this->size(), {this}}; }
Paul's avatar
Paul committed
126

127
    const_iterator begin() const { return {0, {this}}; }
Paul's avatar
Paul committed
128

129
    const_iterator end() const { return {this->size(), {this}}; }
Paul's avatar
Paul committed
130

131
132
133
134
135
    template <class U = T>
    std::vector<U> to_vector() const
    {
        return std::vector<U>(this->begin(), this->end());
    }
Paul's avatar
Paul committed
136

Paul's avatar
Paul committed
137
138
139
140
    friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
    {
        if(!x.empty())
        {
Paul's avatar
Paul committed
141
            os << as_number(x.front());
Paul's avatar
Paul committed
142
            for(std::size_t i = 1; i < x.m_shape.elements(); i++)
Paul's avatar
Paul committed
143
            {
Paul's avatar
Paul committed
144
                os << ", " << as_number(x.m_data[x.m_shape.index(i)]);
Paul's avatar
Paul committed
145
146
147
148
149
            }
        }
        return os;
    }

Paul's avatar
Paul committed
150
    private:
Paul's avatar
Paul committed
151
152
    T* m_data;
    shape m_shape;
Paul's avatar
Paul committed
153
154
};

Paul's avatar
Paul committed
155
template <class T, class U>
156
157
158
159
160
161
162
163
164
165
166
167
168
169
bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
    if(x.get_shape() == y.get_shape())
    {
        for(std::size_t i = 0; i < x.get_shape().elements(); i++)
        {
            if(!float_equal(x[i], y[i]))
                return false;
        }
        return true;
    }
    return false;
}

Paul's avatar
Paul committed
170
171
172
173
174
template <class T, class U>
/// Negation of the tensor_view operator==.
bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
{
    const bool equal = (x == y);
    return !equal;
}
175

Paul's avatar
Paul committed
176
template <class T>
Paul's avatar
Paul committed
177
tensor_view<T> make_view(const shape& s, T* data)
Paul's avatar
Paul committed
178
179
180
181
{
    return {s, data};
}

Paul's avatar
Paul committed
182
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
183
} // namespace migraphx
Paul's avatar
Paul committed
184
185

#endif