"docs/ZH_CN/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "62f4464d4ca466a63d9e797fbcc28fe8fa2c1c44"
static_buffer.hpp 5.32 KB
Newer Older
1
2
3
4
5
6
7
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

8
// static buffer for scalar
9
template <AddressSpaceEnum AddressSpace,
10
11
          typename T,
          index_t N,
12
          bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
13
14
15
16
17
18
19
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

20
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
21

22
23
24
25
26
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }

    // read access
27
    template <index_t I>
28
    __host__ __device__ constexpr const T& operator[](Number<I> i) const
29
    {
30
        return base::operator[](i);
31
32
    }

33
    // write access
34
    template <index_t I>
35
    __host__ __device__ constexpr T& operator()(Number<I> i)
36
    {
37
        return base::operator()(i);
38
    }
39
40
};

41
#ifndef CK_NOGPU
42
// static buffer for vector
43
template <AddressSpaceEnum AddressSpace,
44
45
46
47
48
49
50
          typename S,
          index_t NumOfVector,
          index_t ScalarPerVector,
          bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct StaticBufferTupleOfVector
    : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
51
{
52
53
54
55
56
    using V    = typename vector_type<S, ScalarPerVector>::type;
    using base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>;

    static constexpr auto s_per_v   = Number<ScalarPerVector>{};
    static constexpr auto num_of_v_ = Number<NumOfVector>{};
57

58
    __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
59

60
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
61

62
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
63

64
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
65

66
67
68
69
    // Get S
    // i is offset of S
    template <index_t I>
    __host__ __device__ constexpr const S& operator[](Number<I> i) const
70
    {
71
72
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
73

74
        return base::operator[](i_v).template AsType<S>()[i_s];
75
76
    }

77
78
    // Set S
    // i is offset of S
79
    template <index_t I>
80
    __host__ __device__ constexpr S& operator()(Number<I> i)
81
    {
82
83
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
84

85
        return base::operator()(i_v).template AsType<S>()(i_s);
86
87
    }

88
89
90
91
92
93
    // Get X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr auto GetAsType(Number<I> i) const
94
    {
95
96
97
98
99
100
101
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must  one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;
102

103
        return base::operator[](i_v).template AsType<X>()[i_x];
104
105
    }

106
107
108
109
110
111
    // Set X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
112
    {
113
114
115
116
117
118
119
120
121
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;

        base::operator()(i_v).template AsType<X>()(i_x) = x;
122
123
    }

124
125
    // Get read access to vector_type V
    // i is offset of S, not V. i should be aligned to V
126
    template <index_t I>
127
    __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
128
    {
129
130
131
132
133
        static_assert(i % s_per_v == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;

        return base::operator[](i_v);
134
135
    }

136
137
    // Get write access to vector_type V
    // i is offset of S, not V. i should be aligned to V
138
    template <index_t I>
139
    __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
140
    {
141
        static_assert(i % s_per_v == 0, "wrong!");
142

143
        constexpr auto i_v = i / s_per_v;
144

145
146
        return base::operator()(i_v);
    }
zjing14's avatar
zjing14 committed
147
148
149
150
151
152
153

    __host__ __device__ void Clear()
    {
        const index_t numScalars = NumOfVector * ScalarPerVector;

        static_for<0, Number<numScalars>{}, 1>{}([&](auto i) { SetAsType(i, S{0}); });
    }
154
};
155
#endif
156

157
template <AddressSpaceEnum AddressSpace, typename T, index_t N>
158
159
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
160
    return StaticBuffer<AddressSpace, T, N, true>{};
161
162
}

163
164
165
166
167
168
// Overload of make_static_buffer that takes the size as a LongNumber<N> tag.
// N is a long_index_t here and is passed as the size argument of StaticBuffer;
// the result is a value-initialized StaticBuffer<AddressSpace, T, N, true>.
template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
{
    using buffer_type = StaticBuffer<AddressSpace, T, N, true>;

    return buffer_type{};
}

169
170
} // namespace ck
#endif