static_buffer.hpp 5.54 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
5
6
7
8
9
10
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

11
// static buffer for scalar
12
template <AddressSpaceEnum AddressSpace,
13
14
          typename T,
          index_t N,
15
          bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
16
17
18
19
20
21
22
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

23
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
24

25
26
27
28
29
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }

    // read access
30
    template <index_t I>
31
    __host__ __device__ constexpr const T& operator[](Number<I> i) const
32
    {
33
        return base::operator[](i);
34
35
    }

36
    // write access
37
    template <index_t I>
38
    __host__ __device__ constexpr T& operator()(Number<I> i)
39
    {
40
        return base::operator()(i);
41
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
42
43
44
45
46

    __host__ __device__ void Clear()
    {
        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
    }
47
48
};

49
#ifndef CK_NOGPU
50
// static buffer for vector
51
template <AddressSpaceEnum AddressSpace,
52
53
54
55
56
57
58
          typename S,
          index_t NumOfVector,
          index_t ScalarPerVector,
          bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct StaticBufferTupleOfVector
    : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
59
{
60
61
62
63
64
    using V    = typename vector_type<S, ScalarPerVector>::type;
    using base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>;

    static constexpr auto s_per_v   = Number<ScalarPerVector>{};
    static constexpr auto num_of_v_ = Number<NumOfVector>{};
65

66
    __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
67

68
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
69

70
    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
71

72
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
73

74
75
76
77
    // Get S
    // i is offset of S
    template <index_t I>
    __host__ __device__ constexpr const S& operator[](Number<I> i) const
78
    {
79
80
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
81

82
        return base::operator[](i_v).template AsType<S>()[i_s];
83
84
    }

85
86
    // Set S
    // i is offset of S
87
    template <index_t I>
88
    __host__ __device__ constexpr S& operator()(Number<I> i)
89
    {
90
91
        constexpr auto i_v = i / s_per_v;
        constexpr auto i_s = i % s_per_v;
92

93
        return base::operator()(i_v).template AsType<S>()(i_s);
94
95
    }

96
97
98
99
100
101
    // Get X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr auto GetAsType(Number<I> i) const
102
    {
103
104
105
106
107
108
109
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must  one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;
110

111
        return base::operator[](i_v).template AsType<X>()[i_x];
112
113
    }

114
115
116
117
118
119
    // Set X
    // i is offset of S, not X. i should be aligned to X
    template <typename X,
              index_t I,
              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
120
    {
121
122
123
124
125
126
127
128
129
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};

        static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
        static_assert(i % s_per_x == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;
        constexpr auto i_x = (i % s_per_v) / s_per_x;

        base::operator()(i_v).template AsType<X>()(i_x) = x;
130
131
    }

132
133
    // Get read access to vector_type V
    // i is offset of S, not V. i should be aligned to V
134
    template <index_t I>
135
    __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
136
    {
137
138
139
140
141
        static_assert(i % s_per_v == 0, "wrong!");

        constexpr auto i_v = i / s_per_v;

        return base::operator[](i_v);
142
143
    }

144
145
    // Get write access to vector_type V
    // i is offset of S, not V. i should be aligned to V
146
    template <index_t I>
147
    __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
148
    {
149
        static_assert(i % s_per_v == 0, "wrong!");
150

151
        constexpr auto i_v = i / s_per_v;
152

153
154
        return base::operator()(i_v);
    }
zjing14's avatar
zjing14 committed
155
156
157

    __host__ __device__ void Clear()
    {
Jianfeng Yan's avatar
Jianfeng Yan committed
158
        constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
zjing14's avatar
zjing14 committed
159

Jianfeng Yan's avatar
Jianfeng Yan committed
160
        static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
zjing14's avatar
zjing14 committed
161
    }
162
};
163
#endif
164

165
template <AddressSpaceEnum AddressSpace, typename T, index_t N>
166
167
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
168
    return StaticBuffer<AddressSpace, T, N, true>{};
169
170
}

171
172
173
174
175
176
// Overload of make_static_buffer taking the size as a LongNumber<N>.
// N (a long_index_t) is passed on as the size template argument of
// StaticBuffer; the buffer's fourth template argument is always true.
template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
{
    using buffer_type = StaticBuffer<AddressSpace, T, N, true>;

    return buffer_type{};
}

177
178
} // namespace ck
#endif