tensor_view.hpp 3.04 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP

#include <rtg/shape.hpp>
#include <rtg/float_equal.hpp>

#include <cassert>
#include <cstddef>
#include <iostream>

namespace rtg {

Paul's avatar
Paul committed
11
template <class T>
Paul's avatar
Paul committed
12
13
struct tensor_view
{
Paul's avatar
Paul committed
14
15
    tensor_view() : data_(nullptr), shape_() {}
    tensor_view(shape s, T* d) : data_(d), shape_(s) {}
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
    const shape& get_shape() const { return this->shape_; }
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
    bool empty() const { return data_ == nullptr || shape_.lens().size() == 0; }
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
    std::size_t size() const { return shape_.elements(); }
Paul's avatar
Paul committed
22

Paul's avatar
Paul committed
23
    T* data() { return this->data_; }
Paul's avatar
Paul committed
24

Paul's avatar
Paul committed
25
26
27
    const T* data() const { return this->data_; }

    template <class... Ts>
Paul's avatar
Paul committed
28
29
30
31
32
    const T& operator()(Ts... xs) const
    {
        return data_[shape_.index({xs...})];
    }

Paul's avatar
Paul committed
33
    template <class... Ts>
Paul's avatar
Paul committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    T& operator()(Ts... xs)
    {
        return data_[shape_.index({xs...})];
    }

    T& operator[](std::size_t i)
    {
        assert(!this->empty() && i < this->size());
        return data_[shape_.index(i)];
    }

    const T& operator[](std::size_t i) const
    {
        assert(!this->empty() && i < this->size());
        return data_[shape_.index(i)];
    }

    T& front()
    {
        assert(!this->empty());
        return data_[0];
    }

    const T& front() const
    {
        assert(!this->empty());
        return data_[0];
    }

    T& back()
    {
        assert(!this->empty());
Paul's avatar
Paul committed
66
        return data_[shape_.index(this->size() - 1)];
Paul's avatar
Paul committed
67
68
69
70
71
    }

    const T& back() const
    {
        assert(!this->empty());
Paul's avatar
Paul committed
72
        return data_[shape_.index(this->size() - 1)];
Paul's avatar
Paul committed
73
74
75
76
77
78
79
80
81
82
83
84
    }

    // TODO: Add iterators so it can handle nonpacked tensors
    T* begin()
    {
        assert(this->shape_.packed());
        return data_;
    }

    T* end()
    {
        assert(this->shape_.packed());
Paul's avatar
Paul committed
85
86
87
88
        if(this->empty())
            return data_;
        else
            return data_ + this->size();
Paul's avatar
Paul committed
89
90
91
92
93
94
95
96
97
98
99
    }

    const T* begin() const
    {
        assert(this->shape_.packed());
        return data_;
    }

    const T* end() const
    {
        assert(this->shape_.packed());
Paul's avatar
Paul committed
100
101
102
103
        if(this->empty())
            return data_;
        else
            return data_ + this->size();
Paul's avatar
Paul committed
104
105
106
107
108
109
    }

    friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
    {
        if(x.shape_ == y.shape_)
        {
Paul's avatar
Paul committed
110
            for(std::size_t i = 0; i < x.shape_.elements(); i++)
Paul's avatar
Paul committed
111
            {
Paul's avatar
Paul committed
112
113
                if(!float_equal(x[i], y[i]))
                    return false;
Paul's avatar
Paul committed
114
115
116
117
118
119
            }
            return true;
        }
        return false;
    }

Paul's avatar
Paul committed
120
    friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) { return !(x == y); }
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
123
124
125
126
    friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
    {
        if(!x.empty())
        {
            os << x.front();
Paul's avatar
Paul committed
127
            for(std::size_t i = 1; i < x.shape_.elements(); i++)
Paul's avatar
Paul committed
128
129
130
131
132
133
134
            {
                os << ", " << x.data_[x.shape_.index(i)];
            }
        }
        return os;
    }

Paul's avatar
Paul committed
135
    private:
Paul's avatar
Paul committed
136
137
138
139
    T* data_;
    shape shape_;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
146
147
148
tensor_view<T> make_view(shape s, T* data)
{
    return {s, data};
}

} // namespace rtg

#endif