// elementwise.h
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__

#include "../operator.h"
#include "../tensor.h"

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

// ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE)
//
// Declares `op::OP::NAMESPACE::Descriptor`, the descriptor class of an
// elementwise operator for one backend namespace. The class stores:
//   - `_dtype`:       the dtype passed to its constructor,
//   - `_info`:        the layout metadata (`op::elementwise::ElementwiseInfo`),
//   - `_device_info`: a backend object of type
//                     `op::elementwise::NAMESPACE::DeviceImpl`, owned through
//                     std::unique_ptr (the constructor takes ownership of the
//                     raw pointer it is given).
// The constructor is private; instances are expected to be produced by the
// static `create` factory. Only declarations are generated here: `~Descriptor`,
// `create` and `calculate` must be defined elsewhere for each OP/NAMESPACE.
//
// NOTE(review): the macro opens `namespace op::OP::NAMESPACE` relative to the
// enclosing scope, so it presumably must be expanded at global scope.
// Do not add `//` comments inside the macro body: a trailing backslash would
// splice the following line into the comment.
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE)                                 \
                                                                              \
    namespace op::OP::NAMESPACE {                                             \
    class Descriptor final : public InfiniopDescriptor {                      \
        infiniDtype_t _dtype;                                                 \
        op::elementwise::ElementwiseInfo _info;                               \
        std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
                                                                              \
        Descriptor(                                                           \
            infiniDtype_t dtype,                                              \
            op::elementwise::ElementwiseInfo info,                            \
            op::elementwise::NAMESPACE::DeviceImpl *device_info,              \
            infiniDevice_t device_type,                                       \
            int device_id)                                                    \
            : InfiniopDescriptor{device_type, device_id},                     \
              _dtype(dtype),                                                  \
              _info(std::move(info)),                                         \
              _device_info(device_info) {}                                    \
                                                                              \
    public:                                                                   \
        ~Descriptor();                                                        \
                                                                              \
        static infiniStatus_t create(                                         \
            infiniopHandle_t handle,                                          \
            Descriptor **desc_ptr,                                            \
            infiniopTensorDescriptor_t output_desc,                           \
            std::vector<infiniopTensorDescriptor_t> input_descs);             \
                                                                              \
        infiniStatus_t calculate(                                             \
            void *output,                                                     \
            std::vector<const void *> inputs,                                 \
            void *stream) const;                                              \
    };                                                                        \
    }

namespace op::elementwise {

// Metadata needed to run an elementwise operation: the output layout plus, for
// every input, its shape, strides and contiguity/broadcast flags.
//
// Ownership: the struct owns every array it points to (all allocated with
// `new[]`) and frees them in its destructor. A default-constructed or
// moved-from object holds only null pointers and zero sizes, so destroying or
// copying it is safe.
//
// Invariant: every per-input row (`input_shapes[i]`, `input_strides[i]`) holds
// exactly `ndim` entries, i.e. the output's rank. Consumers and the copy
// constructor can therefore index `[i][d]` for any `d < ndim`.
//
// NOTE(review): raw owning pointers are kept (instead of std::vector) to
// preserve the existing field layout, which backends presumably read directly.
struct ElementwiseInfo {
    size_t output_size = 0;              // output element count (from numel())
    size_t ndim = 0;                     // output rank; length of every shape/stride row
    bool output_contiguous = false;
    bool *input_contiguous = nullptr;    // [input_size]
    bool *input_broadcasted = nullptr;   // [input_size]
    size_t *output_shape = nullptr;      // [ndim]
    size_t **input_shapes = nullptr;     // [input_size][ndim]
    ptrdiff_t *output_strides = nullptr; // [ndim]
    ptrdiff_t **input_strides = nullptr; // [input_size][ndim]
    size_t input_size = 0;               // number of inputs

    ElementwiseInfo() = default;

    ~ElementwiseInfo() { release(); }

    // Deep copy. Delegating to the default constructor makes `*this` a fully
    // constructed object before the body runs. If any allocation below throws,
    // the destructor frees whatever was already copied instead of leaking it.
    ElementwiseInfo(const ElementwiseInfo &other)
        : ElementwiseInfo() {
        output_size = other.output_size;
        ndim = other.ndim;
        output_contiguous = other.output_contiguous;

        input_contiguous = cloneArray(other.input_contiguous, other.input_size);
        input_broadcasted = cloneArray(other.input_broadcasted, other.input_size);
        output_shape = cloneArray(other.output_shape, other.ndim);
        output_strides = cloneArray(other.output_strides, other.ndim);

        // Set before the row copies so the destructor can free partial rows;
        // the row tables are value-initialized (all nullptr) for the same reason.
        input_size = other.input_size;
        if (other.input_shapes != nullptr) {
            input_shapes = new size_t *[input_size]();
            for (size_t i = 0; i < input_size; ++i) {
                input_shapes[i] = cloneArray(other.input_shapes[i], other.ndim);
            }
        }
        if (other.input_strides != nullptr) {
            input_strides = new ptrdiff_t *[input_size]();
            for (size_t i = 0; i < input_size; ++i) {
                input_strides[i] = cloneArray(other.input_strides[i], other.ndim);
            }
        }
    }

    // Steals all arrays and leaves `other` in the empty (default) state.
    ElementwiseInfo(ElementwiseInfo &&other) noexcept
        : output_size(std::exchange(other.output_size, 0)),
          ndim(std::exchange(other.ndim, 0)),
          output_contiguous(std::exchange(other.output_contiguous, false)),
          input_contiguous(std::exchange(other.input_contiguous, nullptr)),
          input_broadcasted(std::exchange(other.input_broadcasted, nullptr)),
          output_shape(std::exchange(other.output_shape, nullptr)),
          input_shapes(std::exchange(other.input_shapes, nullptr)),
          output_strides(std::exchange(other.output_strides, nullptr)),
          input_strides(std::exchange(other.input_strides, nullptr)),
          input_size(std::exchange(other.input_size, 0)) {}

    ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
    ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;

private:
    // Returns a new[]-allocated copy of `count` elements of `src`, or nullptr
    // when `src` is null (std::copy is used because memcpy from nullptr is UB).
    template <typename T>
    static T *cloneArray(const T *src, size_t count) {
        if (src == nullptr) {
            return nullptr;
        }
        T *dst = new T[count];
        std::copy(src, src + count, dst);
        return dst;
    }

    // Frees every owned array; tolerates null tables and null rows.
    void release() noexcept {
        if (input_shapes != nullptr) {
            for (size_t i = 0; i < input_size; ++i) {
                delete[] input_shapes[i];
            }
        }
        if (input_strides != nullptr) {
            for (size_t i = 0; i < input_size; ++i) {
                delete[] input_strides[i];
            }
        }
        delete[] input_shapes;
        delete[] input_strides;
        delete[] input_contiguous;
        delete[] input_broadcasted;
        delete[] output_shape;
        delete[] output_strides;
    }
};

// Fills `info` with the layout metadata of `output_desc` and `input_descs`.
//
// Every input's shape/strides row is stored with exactly `info.ndim` (output
// rank) entries. An input of lower rank is right-aligned against the output:
// its missing leading dimensions are recorded as shape 1 / stride 0, so they
// contribute nothing to an offset computed as sum(index[d] * stride[d]).
//
// Returns:
//   INFINI_STATUS_BAD_PARAM          - output_desc is null, input_descs is
//                                      empty, an input descriptor is null, or
//                                      an input has more dims than the output;
//   INFINI_STATUS_BAD_TENSOR_STRIDES - the output has a broadcast dimension;
//   INFINI_STATUS_SUCCESS            - otherwise.
// All validation happens before `info` is modified, so `info` is untouched
// when an error status is returned.
//
// NOTE(review): `info` is expected to be freshly default-constructed; arrays
// it may already own are not freed here.
inline infiniStatus_t createElementwiseInfo(
    ElementwiseInfo &info,
    infiniopTensorDescriptor_t output_desc,
    std::vector<infiniopTensorDescriptor_t> input_descs) {

    if (!output_desc || input_descs.empty()) {
        return INFINI_STATUS_BAD_PARAM;
    }

    // Destination cannot have broadcast setup
    if (output_desc->hasBroadcastDim()) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    const size_t ndim = output_desc->ndim();

    // An input with more dimensions than the output cannot broadcast into it.
    for (const auto &desc : input_descs) {
        if (!desc || desc->ndim() > ndim) {
            return INFINI_STATUS_BAD_PARAM;
        }
    }

    const size_t input_size = input_descs.size();

    info.ndim = ndim;
    info.output_size = output_desc->numel();
    info.output_contiguous = output_desc->isContiguous();

    // Output layout
    const auto &output_shape = output_desc->shape();
    const auto &output_strides = output_desc->strides();
    info.output_shape = new size_t[ndim];
    info.output_strides = new ptrdiff_t[ndim];
    std::copy_n(output_shape.data(), ndim, info.output_shape);
    std::copy_n(output_strides.data(), ndim, info.output_strides);

    // Per-input tables. The row tables are value-initialized (all nullptr) so
    // the destructor stays safe if a row allocation below throws.
    info.input_contiguous = new bool[input_size];
    info.input_broadcasted = new bool[input_size];
    info.input_shapes = new size_t *[input_size]();
    info.input_strides = new ptrdiff_t *[input_size]();
    info.input_size = input_size;

    for (size_t i = 0; i < input_size; ++i) {
        const auto &desc = input_descs[i];
        const size_t in_ndim = desc->ndim();
        const size_t pad = ndim - in_ndim; // leading dims the input lacks (validated >= 0)

        info.input_contiguous[i] = desc->isContiguous();
        info.input_broadcasted[i] = !info.input_contiguous[i] && (in_ndim != ndim || desc->hasBroadcastDim());

        const auto &in_shape = desc->shape();
        const auto &in_strides = desc->strides();

        info.input_shapes[i] = new size_t[ndim];
        std::fill_n(info.input_shapes[i], pad, size_t{1});
        std::copy_n(in_shape.data(), in_ndim, info.input_shapes[i] + pad);

        info.input_strides[i] = new ptrdiff_t[ndim];
        std::fill_n(info.input_strides[i], pad, ptrdiff_t{0});
        std::copy_n(in_strides.data(), in_ndim, info.input_strides[i] + pad);
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::elementwise

#endif // __INFINIOP_ELEMENTWISE_H__