// tensor_view.hpp
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP

#include <rtg/shape.hpp>
#include <rtg/float_equal.hpp>

#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>

namespace rtg {

Paul's avatar
Paul committed
11
template <class T>
Paul's avatar
Paul committed
12
13
struct tensor_view
{
Paul's avatar
Paul committed
14
15
    tensor_view() : m_data(nullptr) {}
    tensor_view(shape s, T* d) : m_data(d), m_shape(s) {}
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
    const shape& get_shape() const { return this->m_shape; }
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
    bool empty() const { return m_data == nullptr || m_shape.lens().empty(); }
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
    std::size_t size() const { return m_shape.elements(); }
Paul's avatar
Paul committed
22

Paul's avatar
Paul committed
23
    T* data() { return this->m_data; }
Paul's avatar
Paul committed
24

Paul's avatar
Paul committed
25
    const T* data() const { return this->m_data; }
Paul's avatar
Paul committed
26
27

    template <class... Ts>
Paul's avatar
Paul committed
28
29
    const T& operator()(Ts... xs) const
    {
Paul's avatar
Paul committed
30
        return m_data[m_shape.index({xs...})];
Paul's avatar
Paul committed
31
32
    }

Paul's avatar
Paul committed
33
    template <class... Ts>
Paul's avatar
Paul committed
34
35
    T& operator()(Ts... xs)
    {
Paul's avatar
Paul committed
36
        return m_data[m_shape.index({xs...})];
Paul's avatar
Paul committed
37
38
39
40
41
    }

    T& operator[](std::size_t i)
    {
        assert(!this->empty() && i < this->size());
Paul's avatar
Paul committed
42
        return m_data[m_shape.index(i)];
Paul's avatar
Paul committed
43
44
45
46
47
    }

    const T& operator[](std::size_t i) const
    {
        assert(!this->empty() && i < this->size());
Paul's avatar
Paul committed
48
        return m_data[m_shape.index(i)];
Paul's avatar
Paul committed
49
50
51
52
53
    }

    T& front()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
54
        return m_data[0];
Paul's avatar
Paul committed
55
56
57
58
59
    }

    const T& front() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
60
        return m_data[0];
Paul's avatar
Paul committed
61
62
63
64
65
    }

    T& back()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
66
        return m_data[m_shape.index(this->size() - 1)];
Paul's avatar
Paul committed
67
68
69
70
71
    }

    const T& back() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
72
        return m_data[m_shape.index(this->size() - 1)];
Paul's avatar
Paul committed
73
74
75
76
77
    }

    // TODO: Add iterators so it can handle nonpacked tensors
    T* begin()
    {
Paul's avatar
Paul committed
78
79
        assert(this->m_shape.packed());
        return m_data;
Paul's avatar
Paul committed
80
81
82
83
    }

    T* end()
    {
Paul's avatar
Paul committed
84
        assert(this->m_shape.packed());
Paul's avatar
Paul committed
85
        if(this->empty())
Paul's avatar
Paul committed
86
            return m_data;
Paul's avatar
Paul committed
87
        else
Paul's avatar
Paul committed
88
            return m_data + this->size();
Paul's avatar
Paul committed
89
90
91
92
    }

    const T* begin() const
    {
Paul's avatar
Paul committed
93
94
        assert(this->m_shape.packed());
        return m_data;
Paul's avatar
Paul committed
95
96
97
98
    }

    const T* end() const
    {
Paul's avatar
Paul committed
99
        assert(this->m_shape.packed());
Paul's avatar
Paul committed
100
        if(this->empty())
Paul's avatar
Paul committed
101
            return m_data;
Paul's avatar
Paul committed
102
        else
Paul's avatar
Paul committed
103
            return m_data + this->size();
Paul's avatar
Paul committed
104
105
106
107
    }

    friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
    {
Paul's avatar
Paul committed
108
        if(x.m_shape == y.m_shape)
Paul's avatar
Paul committed
109
        {
Paul's avatar
Paul committed
110
            for(std::size_t i = 0; i < x.m_shape.elements(); i++)
Paul's avatar
Paul committed
111
            {
Paul's avatar
Paul committed
112
113
                if(!float_equal(x[i], y[i]))
                    return false;
Paul's avatar
Paul committed
114
115
116
117
118
119
            }
            return true;
        }
        return false;
    }

Paul's avatar
Paul committed
120
    friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) { return !(x == y); }
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
123
124
125
126
    friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
    {
        if(!x.empty())
        {
            os << x.front();
Paul's avatar
Paul committed
127
            for(std::size_t i = 1; i < x.m_shape.elements(); i++)
Paul's avatar
Paul committed
128
            {
Paul's avatar
Paul committed
129
                os << ", " << x.m_data[x.m_shape.index(i)];
Paul's avatar
Paul committed
130
131
132
133
134
            }
        }
        return os;
    }

Paul's avatar
Paul committed
135
    private:
Paul's avatar
Paul committed
136
137
    T* m_data;
    shape m_shape;
Paul's avatar
Paul committed
138
139
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
146
147
148
tensor_view<T> make_view(shape s, T* data)
{
    return {s, data};
}

} // namespace rtg

#endif