// dynamic_buffer_cpu.hpp
#ifndef CK_BUFFER_CPU_HPP
#define CK_BUFFER_CPU_HPP

#include "config.hpp"
#include "enable_if.hpp"
#include "data_type_cpu.hpp"

namespace ck {
namespace cpu {

11
template <AddressSpaceEnum BufferAddressSpace,
12
13
14
15
16
17
18
19
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer
{
    using type = T;

    static_assert(BufferAddressSpace ==
20
                  AddressSpaceEnum::Global); // only valid for global address space on cpu
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_ = T{0};

    constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size, T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

38
    static constexpr AddressSpaceEnum GetAddressSpace() { return BufferAddressSpace; }
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

    constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    constexpr T& operator()(index_t i) { return p_data_[i]; }

    // X should be data_type::type, not directly data_type
    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    constexpr auto Get(index_t i, bool is_valid_element) const
    {
        if constexpr(InvalidElementUseNumericalZeroValue)
        {
            X v;
            if(is_valid_element)
                load_vector(v, &p_data_[i]);
            else
                clear_vector(v);
            return v;
        }
        else
        {
            X v;
            if(is_valid_element)
                load_vector(v, &p_data_[i]);
            else
                set_vector(v, invalid_element_value_);
            return v;
        }
    }

71
    template <InMemoryDataOperationEnum Op,
72
73
74
75
76
77
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    void Update(index_t i, bool is_valid_element, const X& x)
    {
78
        if constexpr(Op == InMemoryDataOperationEnum::Set)
79
80
81
        {
            this->template Set<X>(i, is_valid_element, x);
        }
82
        else if constexpr(Op == InMemoryDataOperationEnum::Add)
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            this->template Set<X>(i, is_valid_element, x + tmp);
        }
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    void Set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        if(is_valid_element)
        {
            store_vector(x, &p_data_[i]);
        }
    }

    static constexpr bool IsStaticBuffer() { return false; }

    static constexpr bool IsDynamicBuffer() { return true; }
};

114
template <AddressSpaceEnum BufferAddressSpace, typename T, typename ElementSpaceSize>
115
116
117
118
119
120
constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

template <
121
    AddressSpaceEnum BufferAddressSpace,
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    typename T,
    typename ElementSpaceSize,
    typename X,
    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
}

} // namespace cpu
} // namespace ck
#endif