// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
5
6
7
8
9

#include "statically_indexed_array.hpp"

namespace ck {

10
// static buffer for scalar
11
template <AddressSpaceEnum AddressSpace,
Chao Liu's avatar
Chao Liu committed
12
          typename S_,
13
          index_t N,
14
          bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
Chao Liu's avatar
Chao Liu committed
15
struct StaticBuffer : public StaticallyIndexedArray<remove_cvref_t<S_>, N>
16
{
Chao Liu's avatar
Chao Liu committed
17
18
19
    using S    = remove_cvref_t<S_>;
    using type = S;
    using base = StaticallyIndexedArray<S, N>;
20
21
22

    __host__ __device__ constexpr StaticBuffer() : base{} {}

Anthony Chang's avatar
Anthony Chang committed
23
24
25
26
27
28
29
30
31
    template <typename... Ys>
    __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
    {
        static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
        StaticBuffer& x = *this;
        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
        return x;
    }

32
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
33

Chao Liu's avatar
Chao Liu committed
34
35
    __host__ __device__ static constexpr index_t Size() { return N; }

36
37
38
39
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }

Chao Liu's avatar
Chao Liu committed
40
    // read access to scalar
41
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
42
    __host__ __device__ constexpr const S& operator[](Number<I> i) const
43
    {
44
        return base::operator[](i);
45
46
    }

Chao Liu's avatar
Chao Liu committed
47
    // write access to scalar
48
    template <index_t I>
Chao Liu's avatar
Chao Liu committed
49
    __host__ __device__ constexpr S& operator()(Number<I> i)
50
    {
51
        return base::operator()(i);
52
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
53

Chao Liu's avatar
Chao Liu committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    // Get a vector (type X)
    // "is" is offset of S, not X.
    // "is" should be aligned to X
    template <typename X_,
              index_t Is,
              typename enable_if<has_same_scalar_type<S, X_>::value, bool>::type = false>
    __host__ __device__ constexpr remove_reference_t<X_> GetAsType(Number<Is> is) const
    {
        using X = remove_cvref_t<X_>;

        constexpr index_t kSPerX = scalar_type<X>::vector_size;

        static_assert(Is % kSPerX == 0, "wrong! \"Is\" should be aligned to X");

        vector_type<S, kSPerX> vx;

        static_for<0, kSPerX, 1>{}(
            [&](auto j) { vx.template AsType<S>()(j) = base::operator[](is + j); });

        return vx.template AsType<X>().template At<0>();
    }

    // Set a vector (type X)
    // "is" is offset of S, not X.
    // "is" should be aligned to X
    template <typename X_,
              index_t Is,
              typename enable_if<has_same_scalar_type<S, X_>::value, bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Number<Is> is, X_ x)
    {
        using X = remove_cvref_t<X_>;

        constexpr index_t kSPerX = scalar_type<X>::vector_size;

        static_assert(Is % kSPerX == 0, "wrong! \"Is\" should be aligned to X");

        const vector_type<S, kSPerX> vx{x};

        static_for<0, kSPerX, 1>{}(
            [&](auto j) { base::operator()(is + j) = vx.template AsType<S>()[j]; });
    }

    __host__ __device__ void Initialize(const S& x)
Jianfeng Yan's avatar
Jianfeng Yan committed
97
    {
Chao Liu's avatar
Chao Liu committed
98
        static_for<0, N, 1>{}([&](auto i) { operator()(i) = S{x}; });
Jianfeng Yan's avatar
Jianfeng Yan committed
99
    }
Anthony Chang's avatar
Anthony Chang committed
100

Chao Liu's avatar
Chao Liu committed
101
102
103
104
105
106
107
108
109
    // FIXME: deprecated
    __host__ __device__ void Clear() { Initialize(0); }

    // FIXME: deprecated
    __host__ __device__ constexpr StaticBuffer& operator=(const S& v)
    {
        Initialize(v);
        return *this;
    }
110
111
};

112
// static buffer for vector
113
template <AddressSpaceEnum AddressSpace,
114
115
116
117
118
119
120
          typename S,
          index_t NumOfVector,
          index_t ScalarPerVector,
          bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct StaticBufferTupleOfVector
    : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
121
{
122
123
124
125
126
    using V    = typename vector_type<S, ScalarPerVector>::type;
    using base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>;

    static constexpr auto s_per_v   = Number<ScalarPerVector>{};
    static constexpr auto num_of_v_ = Number<NumOfVector>{};
Anthony Chang's avatar
Anthony Chang committed
127
    static constexpr auto s_per_buf = s_per_v * num_of_v_;
128

129
    __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
130

131
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
132

133
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
134

135
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
136

Anthony Chang's avatar
Anthony Chang committed
137
138
    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };

139
140
141
142
    // Get S
    // i is offset of S
    template <index_t I>
    __host__ __device__ constexpr const S& operator[](Number<I> i) const
143
    {
144
145
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
146

147
        return base::operator[](i_v).template AsType<S>()[i_s];
148
149
    }

150
151
    // Set S
    // i is offset of S
152
    template <index_t I>
153
    __host__ __device__ constexpr S& operator()(Number<I> i)
154
    {
155
156
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
157

158
        return base::operator()(i_v).template AsType<S>()(i_s);
159
160
    }

161
162
163
164
165
166
    // Get X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr auto GetAsType(Number<I> i) const
167
    {
168
169
170
171
172
173
174
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must  one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;
175

176
        return base::operator[](i_v).template AsType<X>()[i_x];
177
178
    }

179
180
181
182
183
184
    // Set X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
185
    {
186
187
188
189
190
191
192
193
194
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;

        base::operator()(i_v).template AsType<X>()(i_x) = x;
195
196
    }

197
198
    // Get read access to vector_type V
    // i is offset of S, not V. i should be aligned to V
199
    template <index_t I>
200
    __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
201
    {
202
203
204
205
206
        static_assert(i % s_per_v == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;

        return base::operator[](i_v);
207
208
    }

209
210
    // Get write access to vector_type V
    // i is offset of S, not V. i should be aligned to V
211
    template <index_t I>
212
    __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
213
    {
214
        static_assert(i % s_per_v == 0, "wrong!");
215

216
        constexpr auto i_v = i / s_per_v;
217

218
219
        return base::operator()(i_v);
    }
zjing14's avatar
zjing14 committed
220
221
222

    __host__ __device__ void Clear()
    {
Jianfeng Yan's avatar
Jianfeng Yan committed
223
        constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
zjing14's avatar
zjing14 committed
224

Jianfeng Yan's avatar
Jianfeng Yan committed
225
        static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
zjing14's avatar
zjing14 committed
226
    }
227
228
};

229
template <AddressSpaceEnum AddressSpace, typename T, index_t N>
230
231
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
232
    return StaticBuffer<AddressSpace, T, N, true>{};
233
234
}

235
236
237
238
239
240
// Overload of make_static_buffer that takes its size as a LongNumber.
// StaticBuffer's size parameter is index_t, so N is first checked to be a non-negative value
// representable as index_t; the conversion is then made explicit instead of relying on an
// implicit (and, for out-of-range N, ill-formed narrowing) template-argument conversion.
template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
{
    static_assert(N >= 0 && static_cast<long_index_t>(static_cast<index_t>(N)) == N,
                  "wrong! N does not fit in index_t");

    return StaticBuffer<AddressSpace, T, static_cast<index_t>(N), true>{};
}

241
} // namespace ck