elementwise.h 9.44 KB
Newer Older
1
2
3
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__

4
#include "../../utils.h"
5
6
7
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
8
#include <array>
9
10
#include <cstring>
#include <iostream>
11
12
13
14
#include <memory>
#include <numeric>
#include <vector>

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/**
 * @brief Declares the class `op::OP::NAMESPACE::Descriptor` for an elementwise
 *        operator.
 *
 * The generated class derives from InfiniopDescriptor and stores:
 *  - the data type passed to its constructor,
 *  - an op::elementwise::ElementwiseInfo holding the tensor metadata,
 *  - a std::unique_ptr to op::elementwise::NAMESPACE::DeviceImpl, which takes
 *    ownership of the raw `device_info` pointer given to the constructor,
 *  - the workspace size reported by workspaceSize().
 *
 * The constructor is private. The destructor, `create` and `calculate` are
 * only declared by this macro; their definitions are not part of it.
 *
 * @param OP        Operator name, used as the middle namespace component.
 * @param NAMESPACE Backend name, used as the innermost namespace and to select
 *                  op::elementwise::NAMESPACE::DeviceImpl.
 */
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE)                                 \
                                                                              \
    namespace op::OP::NAMESPACE {                                             \
    class Descriptor final : public InfiniopDescriptor {                      \
        infiniDtype_t _dtype;                                                 \
        op::elementwise::ElementwiseInfo _info;                               \
        std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
        size_t _workspace_size;                                               \
                                                                              \
        Descriptor(                                                           \
            infiniDtype_t dtype,                                              \
            op::elementwise::ElementwiseInfo info,                            \
            op::elementwise::NAMESPACE::DeviceImpl *device_info,              \
            size_t workspace_size,                                            \
            infiniDevice_t device_type,                                       \
            int device_id)                                                    \
            : InfiniopDescriptor{device_type, device_id},                     \
              _dtype(dtype),                                                  \
              _info(std::move(info)),                                         \
              _device_info(device_info),                                      \
              _workspace_size(workspace_size) {}                              \
                                                                              \
    public:                                                                   \
        ~Descriptor();                                                        \
                                                                              \
        size_t workspaceSize() const { return _workspace_size; }              \
                                                                              \
        static infiniStatus_t create(                                         \
            infiniopHandle_t handle,                                          \
            Descriptor **desc_ptr,                                            \
            infiniopTensorDescriptor_t output_desc,                           \
            std::vector<infiniopTensorDescriptor_t> input_descs);             \
                                                                              \
        infiniStatus_t calculate(                                             \
            void *workspace, size_t workspace_size,                           \
            void *output,                                                     \
            std::vector<const void *> inputs,                                 \
            void *stream) const;                                              \
    };                                                                        \
    }

namespace op::elementwise {

58
59
60
61
62
63
64
65
66
67
68
/**
 * @brief Stores the metadata required for performing an elementwise operation.
 *
 * This struct encapsulates shape, stride, and layout information for both
 * output and multiple input tensors involved in an elementwise operation.
 *
 * The metadata is held in a single std::vector<size_t> buffer, so its memory is
 * released automatically. No special member functions are declared, so the
 * compiler-generated copy and move operations apply.
 *
 * Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
 */
69
struct ElementwiseInfo {
70
private:
71
    std::vector<size_t> _meta;
72
73
74
75
76
    size_t _output_size;
    size_t _input_size;
    size_t _ndim;
    bool _output_contiguous;

77
    ElementwiseInfo(std::vector<size_t> meta,
78
79
80
81
82
83
84
                    size_t output_size,
                    size_t input_size,
                    size_t ndim,
                    bool output_contiguous)
        : _meta(std::move(meta)), _output_size(output_size),
          _input_size(input_size), _ndim(ndim),
          _output_contiguous(output_contiguous) {}
85
86

public:
87
    // Get the Memory size of the meta data in bytes
88
    inline size_t getMetaMemSize() const {
89
        return _meta.size() * sizeof(size_t);
90
91
    }
    inline const int8_t *getMetaStart() const {
92
        return reinterpret_cast<const int8_t *>(_meta.data());
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    }
    inline size_t getOutputSize() const {
        return _output_size;
    }
    inline size_t getInputSize() const {
        return _input_size;
    }
    inline size_t getNdim() const {
        return _ndim;
    }
    inline bool isOutputContiguous() const {
        return _output_contiguous;
    }
    inline const size_t *getOutputShape() const {
        return reinterpret_cast<const size_t *>(_meta.data());
    }
    inline const ptrdiff_t *getOutputStrides() const {
        return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
    }
    inline const size_t *getAllInputShapes() const {
        return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
    }
    inline const size_t *getInputShape(const size_t &index) const {
        if (index < _input_size) {
            return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
        }
        return nullptr;
    }
    inline const ptrdiff_t *getAllInputStrides() const {
        return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
    }
    inline const ptrdiff_t *getInputStrides(const size_t &index) const {
        if (index < _input_size) {
            return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
127
        }
128
129
130
131
132
133
134
135
        return nullptr;
    }
    inline const bool *getInputContiguous() const {
        return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
    }
    inline const bool *getInputBroadcasted() const {
        return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
    }
136

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    using ResultType = utils::Result<ElementwiseInfo>;

    /**
     * @brief Construct ElementwiseInfo from output and input tensor descriptors.
     * @param output_desc Descriptor of the output tensor.
     * @param input_descs Descriptors of the input tensors.
     * @return Result<ElementwiseInfo> with the successfully constructed ElementwiseInfo,
     *         or the status code.
     */
    static ResultType create(
        infiniopTensorDescriptor_t output_desc,
        std::vector<infiniopTensorDescriptor_t> input_descs) {

        if (!output_desc || input_descs.empty()) {
            return INFINI_STATUS_BAD_PARAM;
        }
153

154
155
156
157
        // Destination cannot have broadcast setup
        if (output_desc->hasBroadcastDim()) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
158

159
160
161
162
163
164
165
166
167
168
169
170
        auto input_size = input_descs.size();
        auto ndim = output_desc->ndim();
        auto output_size = output_desc->numel();
        auto output_contiguous = output_desc->isContiguous();

        // Allocate memory for meta
        auto shape_unit = output_desc->dim(0);
        auto stride_unit = output_desc->stride(0);
        size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
                             + input_size * ndim * sizeof(shape_unit)
                             + input_size * ndim * sizeof(stride_unit)
                             + 2 * input_size * sizeof(bool);
171
        std::vector<size_t> meta(CEIL_DIV(meta_mem_size, sizeof(size_t)));
172
        int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());
173

174
175
176
        const auto output_shape = output_desc->shape();
        const auto output_strides = output_desc->strides();

177
178
179
180
181
182
183
        // Pointers to the sections within _meta
        size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
        ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
        size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
        ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
        bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
        bool *input_broadcasted = input_contiguous + input_size;
184

185
186
187
        // Copy output shape and strides
        std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
        std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
188

189
190
191
192
193
194
195
196
197
        // Copy input shapes, strides, contiguous, and broadcasted flags
        for (size_t i = 0; i < input_size; ++i) {
            auto &desc = input_descs[i];
            const auto in_shape = desc->shape();
            const auto in_strides = desc->strides();
            std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
            std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
            input_contiguous[i] = desc->isContiguous();
            input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
198
        }
199

200
        ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
201
        return ResultType(std::move(info));
202
    }
203
};
204
205
206
} // namespace op::elementwise

#endif // __INFINIOP_ELEMENTWISE_H__