// elementwise.h
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__

#include "../../utils.h"
#include "../operator.h"
#include "../tensor.h"

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

/**
 * @brief Declares the descriptor class for an elementwise operator on one backend.
 *
 * Expands to `namespace op::OP::NAMESPACE { class Descriptor ... }`, a final
 * subclass of InfiniopDescriptor that holds the dtype passed at construction,
 * the shared ElementwiseInfo metadata, and an owned (via std::unique_ptr)
 * `op::elementwise::NAMESPACE::DeviceImpl`.
 *
 * The constructor and data members are private; the public surface is the
 * static `create` factory, `calculate`, and the destructor. The macro only
 * declares those three, so each operator/backend must define them elsewhere.
 *
 * @param OP        Operator name; becomes the middle component of `op::OP::NAMESPACE`.
 * @param NAMESPACE Backend name; must name a namespace under `op::elementwise`
 *                  that declares `DeviceImpl`.
 *
 * NOTE: every line of the body ends in a line continuation, so only block
 * comments may ever be placed inside it.
 */
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE)                                 \
                                                                              \
    namespace op::OP::NAMESPACE {                                             \
    class Descriptor final : public InfiniopDescriptor {                      \
        infiniDtype_t _dtype;                                                 \
        op::elementwise::ElementwiseInfo _info;                               \
        std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
                                                                              \
        Descriptor(                                                           \
            infiniDtype_t dtype,                                              \
            op::elementwise::ElementwiseInfo info,                            \
            op::elementwise::NAMESPACE::DeviceImpl *device_info,              \
            infiniDevice_t device_type,                                       \
            int device_id)                                                    \
            : InfiniopDescriptor{device_type, device_id},                     \
              _dtype(dtype),                                                  \
              _info(std::move(info)),                                         \
              _device_info(device_info) {}                                    \
                                                                              \
    public:                                                                   \
        ~Descriptor();                                                        \
                                                                              \
        static infiniStatus_t create(                                         \
            infiniopHandle_t handle,                                          \
            Descriptor **desc_ptr,                                            \
            infiniopTensorDescriptor_t output_desc,                           \
            std::vector<infiniopTensorDescriptor_t> input_descs);             \
                                                                              \
        infiniStatus_t calculate(                                             \
            void *output,                                                     \
            std::vector<const void *> inputs,                                 \
            void *stream) const;                                              \
    };                                                                        \
    }

namespace op::elementwise {

51
52
53
54
55
56
57
58
59
60
61
/**
 * @brief Stores the metadata required for performing an elementwise operation.
 *
 * This struct encapsulates shape, stride, and layout information for both
 * output and multiple input tensors involved in an elementwise operation.
 *
 * Memory is manually managed and freed in the destructor.
 * Supports move construction but disallows copy construction and copy/move assignment.
 *
 * Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
 */
62
struct ElementwiseInfo {
63
64
65
66
private:
    ElementwiseInfo() = default;

public:
67
68
69
    size_t output_size;
    size_t ndim;
    bool output_contiguous;
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    bool *input_contiguous;
    bool *input_broadcasted;
    size_t *output_shape;
    size_t **input_shapes;
    ptrdiff_t *output_strides;
    ptrdiff_t **input_strides;
    size_t input_size;

    ~ElementwiseInfo() {
        delete[] input_contiguous;
        delete[] input_broadcasted;
        delete[] output_shape;
        delete[] output_strides;

        for (size_t i = 0; i < input_size; ++i) {
            delete[] input_shapes[i];
            delete[] input_strides[i];
        }
        delete[] input_shapes;
        delete[] input_strides;
    }

    ElementwiseInfo(ElementwiseInfo &&other) noexcept
        : output_size(other.output_size),
          ndim(other.ndim),
          output_contiguous(other.output_contiguous),
          input_contiguous(other.input_contiguous),
          input_broadcasted(other.input_broadcasted),
          output_shape(other.output_shape),
          input_shapes(other.input_shapes),
          output_strides(other.output_strides),
          input_strides(other.input_strides),
          input_size(other.input_size) {
        other.input_contiguous = nullptr;
        other.input_broadcasted = nullptr;
        other.output_shape = nullptr;
        other.input_shapes = nullptr;
        other.output_strides = nullptr;
        other.input_strides = nullptr;
        other.input_size = 0;
    }

112
    ElementwiseInfo(const ElementwiseInfo &other) = delete;
113
114
    ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
    ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
    using ResultType = utils::Result<ElementwiseInfo>;

    /**
     * @brief Construct ElementwiseInfo from output and input tensor descriptors.
     * @param output_desc Descriptor of the output tensor.
     * @param input_descs Descriptors of the input tensors.
     * @return Result<ElementwiseInfo> with the successfully constructed ElementwiseInfo,
     *         or the status code.
     */
    static ResultType create(
        infiniopTensorDescriptor_t output_desc,
        std::vector<infiniopTensorDescriptor_t> input_descs) {

        if (!output_desc || input_descs.empty()) {
            return INFINI_STATUS_BAD_PARAM;
        }
132

133
134
135
136
        // Destination cannot have broadcast setup
        if (output_desc->hasBroadcastDim()) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
137

138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
        ElementwiseInfo info;
        info.input_size = input_descs.size();
        info.ndim = output_desc->ndim();
        info.output_size = output_desc->numel();
        info.output_contiguous = output_desc->isContiguous();

        // Allocate memory for arrays
        info.input_contiguous = new bool[info.input_size];
        info.input_broadcasted = new bool[info.input_size];
        info.output_shape = new size_t[info.ndim];
        info.output_strides = new ptrdiff_t[info.ndim];
        info.input_shapes = new size_t *[info.input_size];
        info.input_strides = new ptrdiff_t *[info.input_size];

        // Fill arrays
        const auto output_shape = output_desc->shape();
        const auto output_strides = output_desc->strides();
        std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
        std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));

        for (size_t i = 0; i < info.input_size; ++i) {
            auto &desc = input_descs[i];
            info.input_contiguous[i] = desc->isContiguous();
            info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());

            info.input_shapes[i] = new size_t[desc->ndim()];
            const auto &in_shape = desc->shape();
            std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));

            info.input_strides[i] = new ptrdiff_t[desc->ndim()];
            const auto &in_strides = desc->strides();
            std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
        }
171

172
        return ResultType(std::move(info));
173
    }
174
};
175
176
177
} // namespace op::elementwise

#endif // __INFINIOP_ELEMENTWISE_H__