// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
5
6
7
8
9

#include "statically_indexed_array.hpp"

namespace ck {

10
// static buffer for scalar
11
template <AddressSpaceEnum AddressSpace,
12
13
          typename T,
          index_t N,
14
          bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
15
16
17
18
19
20
21
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

Anthony Chang's avatar
Anthony Chang committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
    template <typename... Ys>
    __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
    {
        static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
        StaticBuffer& x = *this;
        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
        return x;
    }

    __host__ __device__ constexpr StaticBuffer& operator=(const T& y)
    {
        StaticBuffer& x = *this;
        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
        return x;
    }

38
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
39

40
41
42
43
44
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }

    // read access
45
    template <index_t I>
46
    __host__ __device__ constexpr const T& operator[](Number<I> i) const
47
    {
48
        return base::operator[](i);
49
50
    }

51
    // write access
52
    template <index_t I>
53
    __host__ __device__ constexpr T& operator()(Number<I> i)
54
    {
55
        return base::operator()(i);
56
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
57

Anthony Chang's avatar
Anthony Chang committed
58
    __host__ __device__ void Set(T x)
Jianfeng Yan's avatar
Jianfeng Yan committed
59
    {
Anthony Chang's avatar
Anthony Chang committed
60
        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
Jianfeng Yan's avatar
Jianfeng Yan committed
61
    }
Anthony Chang's avatar
Anthony Chang committed
62
63

    __host__ __device__ void Clear() { Set(T{0}); }
64
65
};

66
// static buffer for vector
67
template <AddressSpaceEnum AddressSpace,
68
69
70
71
72
73
74
          typename S,
          index_t NumOfVector,
          index_t ScalarPerVector,
          bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct StaticBufferTupleOfVector
    : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
75
{
76
77
78
79
80
    using V    = typename vector_type<S, ScalarPerVector>::type;
    using base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>;

    static constexpr auto s_per_v   = Number<ScalarPerVector>{};
    static constexpr auto num_of_v_ = Number<NumOfVector>{};
Anthony Chang's avatar
Anthony Chang committed
81
    static constexpr auto s_per_buf = s_per_v * num_of_v_;
82

83
    __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
84

85
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
86

87
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
88

89
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
90

Anthony Chang's avatar
Anthony Chang committed
91
92
    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };

93
94
95
96
    // Get S
    // i is offset of S
    template <index_t I>
    __host__ __device__ constexpr const S& operator[](Number<I> i) const
97
    {
98
99
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
100

101
        return base::operator[](i_v).template AsType<S>()[i_s];
102
103
    }

104
105
    // Set S
    // i is offset of S
106
    template <index_t I>
107
    __host__ __device__ constexpr S& operator()(Number<I> i)
108
    {
109
110
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
111

112
        return base::operator()(i_v).template AsType<S>()(i_s);
113
    }
letaoqin's avatar
letaoqin committed
114
115
116
117
118
    template <index_t I>
    __host__ __device__ constexpr S& operator()(Number<I> i_v, Number<I> i_s)
    {
        return base::operator()(i_v).template AsType<S>()(i_s);
    }
119

120
121
122
123
124
125
    // Get X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr auto GetAsType(Number<I> i) const
126
    {
127
128
129
130
131
132
133
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must  one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;
134

135
        return base::operator[](i_v).template AsType<X>()[i_x];
136
137
    }

138
139
140
141
142
143
    // Set X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
144
    {
145
146
147
148
149
150
151
152
153
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;

        base::operator()(i_v).template AsType<X>()(i_x) = x;
154
155
    }

156
157
    // Get read access to vector_type V
    // i is offset of S, not V. i should be aligned to V
158
    template <index_t I>
159
    __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
160
    {
161
162
163
164
165
        static_assert(i % s_per_v == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;

        return base::operator[](i_v);
166
167
    }

168
169
    // Get write access to vector_type V
    // i is offset of S, not V. i should be aligned to V
170
    template <index_t I>
171
    __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
172
    {
173
        static_assert(i % s_per_v == 0, "wrong!");
174

175
        constexpr auto i_v = i / s_per_v;
176

177
178
        return base::operator()(i_v);
    }
zjing14's avatar
zjing14 committed
179
180
181

    __host__ __device__ void Clear()
    {
Jianfeng Yan's avatar
Jianfeng Yan committed
182
        constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
zjing14's avatar
zjing14 committed
183

Jianfeng Yan's avatar
Jianfeng Yan committed
184
        static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
zjing14's avatar
zjing14 committed
185
    }
186
187
};

188
template <AddressSpaceEnum AddressSpace, typename T, index_t N>
189
190
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
191
    return StaticBuffer<AddressSpace, T, N, true>{};
192
193
}

194
195
196
197
198
199
// Overload of make_static_buffer that takes the element count as a LongNumber<N> tag.
// Returns a default-constructed StaticBuffer<AddressSpace, T, N, true>.
template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
{
    using buffer_type = StaticBuffer<AddressSpace, T, N, true>;

    return buffer_type{};
}

200
} // namespace ck