// tensor_view.hpp
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP

#include <rtg/shape.hpp>
#include <rtg/float_equal.hpp>

#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>

namespace rtg {

/// Non-owning view over a typed buffer whose layout is described by a `shape`.
///
/// The view holds a raw pointer plus a copy of the shape. It never allocates or
/// frees memory, so the caller must keep the underlying buffer alive for as long
/// as the view (or anything derived from it) is in use.
template<class T>
struct tensor_view
{
    /// Constructs an empty view: null data and a default-constructed shape.
    tensor_view()
    : data_(nullptr), shape_()
    {}

    /// Constructs a view over `d` laid out according to `s`.
    /// @param s shape describing the buffer's dimensions/layout
    /// @param d pointer to the buffer; not owned by the view
    tensor_view(shape s, T* d)
    : data_(d), shape_(std::move(s))
    {}

    /// Returns the shape this view was constructed with.
    const shape& get_shape() const
    {
        return this->shape_;
    }

    /// True when there is no element that can be read: no buffer, no
    /// dimensions, or a shape whose element count is zero.
    bool empty() const
    {
        // The element-count check keeps front()/back()/operator<< from touching
        // data_ when some dimension is zero (back() would otherwise compute
        // index(size() - 1) with an unsigned wrap-around).
        return data_ == nullptr || shape_.lens().size() == 0 || this->size() == 0;
    }

    /// Number of logical elements described by the shape.
    std::size_t size() const
    {
        return shape_.elements();
    }

    /// Raw pointer to the start of the viewed buffer (may be null).
    T* data()
    {
        return this->data_;
    }

    /// Raw pointer to the start of the viewed buffer (may be null).
    const T* data() const
    {
        return this->data_;
    }

    /// Multi-dimensional access: the coordinates are forwarded to
    /// `shape::index` and the result is used as the offset into the buffer.
    template<class... Ts>
    const T& operator()(Ts... xs) const
    {
        return data_[shape_.index({xs...})];
    }

    /// Multi-dimensional access: the coordinates are forwarded to
    /// `shape::index` and the result is used as the offset into the buffer.
    template<class... Ts>
    T& operator()(Ts... xs)
    {
        return data_[shape_.index({xs...})];
    }

    /// Access by logical element number `i` in [0, size()); the storage offset
    /// comes from `shape::index(i)`, so non-packed layouts are honored.
    T& operator[](std::size_t i)
    {
        assert(!this->empty() && i < this->size());
        return data_[shape_.index(i)];
    }

    /// Access by logical element number `i` in [0, size()).
    const T& operator[](std::size_t i) const
    {
        assert(!this->empty() && i < this->size());
        return data_[shape_.index(i)];
    }

    /// First element in the buffer; the view must not be empty.
    T& front()
    {
        assert(!this->empty());
        return data_[0];
    }

    /// First element in the buffer; the view must not be empty.
    const T& front() const
    {
        assert(!this->empty());
        return data_[0];
    }

    /// Last logical element; the view must not be empty.
    T& back()
    {
        assert(!this->empty());
        return data_[shape_.index(this->size() - 1)];
    }

    /// Last logical element; the view must not be empty.
    const T& back() const
    {
        assert(!this->empty());
        return data_[shape_.index(this->size() - 1)];
    }

    // Raw-pointer iteration walks the buffer contiguously, which is only
    // meaningful for packed shapes (asserted below).
    // TODO: Add iterators so it can handle nonpacked tensors
    T* begin()
    {
        assert(this->shape_.packed());
        return data_;
    }

    T* end()
    {
        assert(this->shape_.packed());
        if(this->empty())
            return data_;
        return data_ + this->size();
    }

    const T* begin() const
    {
        assert(this->shape_.packed());
        return data_;
    }

    const T* end() const
    {
        assert(this->shape_.packed());
        if(this->empty())
            return data_;
        return data_ + this->size();
    }

    /// Two views are equal when their shapes match and every logical element
    /// compares equal under `float_equal`. Empty views are equal only to other
    /// empty views of the same shape (their elements cannot be read).
    friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
    {
        if(!(x.shape_ == y.shape_))
            return false;
        // Indexing an empty view trips operator[]'s assertion (or dereferences
        // null in release builds), so decide the empty cases up front.
        if(x.empty() || y.empty())
            return x.empty() && y.empty();
        for(std::size_t i = 0; i < x.size(); i++)
        {
            if(!float_equal(x[i], y[i]))
                return false;
        }
        return true;
    }

    friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y)
    {
        return !(x == y);
    }

    /// Writes the logical elements as a comma-separated list; writes nothing
    /// for an empty view.
    friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
    {
        if(!x.empty())
        {
            os << x.front();
            for(std::size_t i = 1; i < x.size(); i++)
            {
                os << ", " << x[i];
            }
        }
        return os;
    }

private:
    T* data_;      // not owned
    shape shape_;
};

/// Convenience factory that wraps `data` in a `tensor_view<T>` laid out as `s`,
/// letting the element type be deduced from the pointer.
/// The returned view does not take ownership of `data`.
template<class T>
tensor_view<T> make_view(shape s, T* data)
{
    return tensor_view<T>(s, data);
}

} // namespace rtg

#endif