"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "b16599cdf2f25ef024c409db3a3c80bf78707b65"
static_buffer.hpp 4.57 KB
Newer Older
1
2
3
4
5
6
7
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

8
9
10
11
template <AddressSpaceEnum_t BufferAddressSpace,
          typename T,
          index_t N,
          bool InvalidElementUseNumericalZeroValue>
12
13
14
15
16
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

17
18
    T invalid_element_value_ = T{0};

19
20
    __host__ __device__ constexpr StaticBuffer() : base{} {}

21
22
23
24
25
    __host__ __device__ constexpr StaticBuffer(T invalid_element_value)
        : base{}, invalid_element_value_{invalid_element_value}
    {
    }

26
    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
27
28
29
30
    {
        return BufferAddressSpace;
    }

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    template <index_t I>
    __host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
    {
        if constexpr(InvalidElementUseNumericalZeroValue)
        {
            return is_valid_element ? At(i) : T{0};
        }
        else
        {
            return is_valid_element ? At(i) : invalid_element_value_;
        }
    }

    template <index_t I>
    __host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
    {
        if(is_valid_element)
        {
            At(i) = x;
        }
    }

53
54
55
56
57
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Fixed-size buffer of N vector-typed elements T, held by value in the
// StaticallyIndexedArray base. T is expected to provide `type`, `d1_t` and an
// `AsType<...>()` lane view (all used below). Besides whole-vector access via
// GetVector(), scalars are addressed by a flattened index: scalar I lives in
// vector I / vector_size at lane I % vector_size.
//
// InvalidElementUseNumericalZeroValue selects what the const GetElement()
// yields for an invalid element: VecBaseType{0} when true, otherwise the
// per-object invalid_element_value_.
template <AddressSpaceEnum_t BufferAddressSpace,
          typename T,
          index_t N,
          bool InvalidElementUseNumericalZeroValue>
struct StaticBufferV2 : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    // Scalar type of a single vector lane.
    using VecBaseType = typename T::d1_t;

    // Number of VecBaseType lanes per T. Computed directly (rather than by
    // calling a member function from an in-class static initializer) so it
    // does not depend on the class being complete.
    static constexpr index_t vector_size = sizeof(typename T::type) / sizeof(VecBaseType);

    __host__ __device__ static constexpr index_t GetVectorSize() { return vector_size; }

    // Scalar yielded by the const GetElement() for invalid elements when
    // InvalidElementUseNumericalZeroValue is false.
    VecBaseType invalid_element_value_ = VecBaseType{0};

    // Vector built from the same invalid value; not read by any member of this
    // struct.
    T invalid_vec_value_ = T{0};

    __host__ __device__ constexpr StaticBufferV2() : base{} {}

    // Initializers are listed in member declaration order (the order in which
    // they actually run), which avoids -Wreorder.
    __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value)
        : base{},
          invalid_element_value_{invalid_element_value},
          invalid_vec_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    // Mutable reference to the whole vector at index vec_id.
    template <index_t I>
    __host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
    {
        return this->At(vec_id);
    }

    // Read-only reference to the whole vector at index vec_id.
    template <index_t I>
    __host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
    {
        return this->At(vec_id);
    }

    // Mutable reference to scalar I. The bool argument is ignored: this
    // overload performs no validity check.
    template <index_t I>
    __host__ __device__ constexpr auto& GetElement(Number<I> /*i*/, bool)
    {
        constexpr auto vec_id  = Number<I / vector_size>{};
        constexpr auto vec_off = Number<I % vector_size>{};

        return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
    }

    // Copy of scalar I when is_valid_element is true; otherwise the
    // invalid-element value selected by InvalidElementUseNumericalZeroValue.
    template <index_t I>
    __host__ __device__ constexpr auto GetElement(Number<I> /*i*/, bool is_valid_element) const
    {
        constexpr auto vec_id  = Number<I / vector_size>{};
        constexpr auto vec_off = Number<I % vector_size>{};

        if constexpr(InvalidElementUseNumericalZeroValue)
        {
            return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
                                    : VecBaseType{0};
        }
        else
        {
            return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
                                    : invalid_element_value_;
        }
    }

    // Read access to scalar i (treated as valid).
    template <index_t I>
    __host__ __device__ constexpr auto operator[](Number<I> i) const
    {
        return GetElement(i, true);
    }

    // Write access to scalar i.
    template <index_t I>
    __host__ __device__ constexpr auto& operator()(Number<I> i)
    {
        return GetElement(i, true);
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

150
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
151
152
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
153
154
155
156
157
158
159
    return StaticBuffer<BufferAddressSpace, T, N, true>{};
}

template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
{
    // Factory for a StaticBuffer of N elements of T whose Get() yields
    // invalid_element_value (instead of T{0}) for invalid elements.
    using buffer_t = StaticBuffer<BufferAddressSpace, T, N, false>;
    return buffer_t{invalid_element_value};
}

} // namespace ck
#endif