// elementwise.h
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__

#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <cstddef>
#include <memory>
#include <numeric>
#include <type_traits>
#include <utility>
#include <vector>

#define DEVICE_IMPL(NAMESPACE)                                                   \
                                                                                 \
    namespace op::elementwise::NAMESPACE {                                       \
    class DeviceImpl final {                                                     \
        struct Opaque;                                                           \
        std::unique_ptr<Opaque> _opaque;                                         \
                                                                                 \
        DeviceImpl(Opaque *opaque) : _opaque(opaque) {}                          \
                                                                                 \
    public:                                                                      \
        ~DeviceImpl() = default;                                                 \
                                                                                 \
        template <typename... Args>                                              \
        static infiniStatus_t create(                                            \
            DeviceImpl **device_info,                                            \
            Args &&...args);                                                     \
                                                                                 \
        /* Invoke elementwise operation when all inputs have the same type */    \
        template <typename Op, typename Tdata, typename... Args>                 \
        void calculate(                                                          \
            const op::elementwise::ElementwiseInfo &info,                        \
            void *output,                                                        \
            const std::vector<const void *> &inputs,                             \
            Args &&...args);                                                     \
                                                                                 \
        /* Invoke elementwise operation for different input types */             \
        template <typename Op, typename Tout, typename... Tin,                   \
                  typename... Args,                                              \
                  std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> \
        void calculate(                                                          \
            const op::elementwise::ElementwiseInfo &info,                        \
            void *output,                                                        \
            const std::vector<const void *> &inputs,                             \
            Args &&...args);                                                     \
    };                                                                           \
    }

// Declares the class `op::OP::NAMESPACE::Descriptor`, derived from
// `InfiniopDescriptor`.
//
// OP:        name of the operator namespace the class is placed in.
// NAMESPACE: name of the nested namespace; the class refers to
//            `op::elementwise::NAMESPACE::DeviceImpl`, so a matching
//            DEVICE_IMPL(NAMESPACE) expansion must be visible.
//
// The class stores a data type, an ElementwiseInfo and an owning pointer to a
// DeviceImpl. The destructor, `create` and `calculate` are only declared here;
// their definitions are not part of this header.
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE)                                 \
                                                                              \
    namespace op::OP::NAMESPACE {                                             \
    class Descriptor final : public InfiniopDescriptor {                      \
        infiniDtype_t _dtype;                                                 \
        op::elementwise::ElementwiseInfo _info;                               \
        std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
                                                                              \
        /* Private: `info` is a by-value sink and is moved into `_info`; */   \
        /* `device_info` becomes owned by the unique_ptr member. */           \
        Descriptor(                                                           \
            infiniDtype_t dtype,                                              \
            op::elementwise::ElementwiseInfo info,                            \
            op::elementwise::NAMESPACE::DeviceImpl *device_info,              \
            infiniDevice_t device_type,                                       \
            int device_id)                                                    \
            : InfiniopDescriptor{device_type, device_id},                     \
              _dtype(dtype),                                                  \
              _info(std::move(info)),                                         \
              _device_info(device_info) {}                                    \
                                                                              \
    public:                                                                   \
        ~Descriptor();                                                        \
                                                                              \
        /* Declaration only; reports its outcome as an infiniStatus_t. */     \
        static infiniStatus_t create(                                         \
            infiniopHandle_t handle,                                          \
            Descriptor **desc_ptr,                                            \
            infiniopTensorDescriptor_t output_desc,                           \
            std::vector<infiniopTensorDescriptor_t> input_descs);             \
                                                                              \
        /* Declaration only; reports its outcome as an infiniStatus_t. */     \
        infiniStatus_t calculate(                                             \
            void *output,                                                     \
            std::vector<const void *> inputs,                                 \
            void *stream) const;                                              \
    };                                                                        \
    }

namespace op::elementwise {

// Shape, stride and layout metadata describing one elementwise operation.
// Populated by createElementwiseInfo(). The scalar members carry default
// initializers so that a default-constructed instance never holds
// indeterminate values.
struct ElementwiseInfo {
    // Number of elements of the output tensor.
    size_t output_size = 0;
    // Number of dimensions of the output tensor.
    size_t ndim = 0;
    // Whether the output tensor is contiguous.
    bool output_contiguous = false;
    // Per input: whether that input tensor is contiguous.
    std::vector<bool> input_contiguous;
    // Per input: whether that input is flagged as broadcasted.
    std::vector<bool> input_broadcasted;
    // Shape of the output tensor.
    std::vector<size_t> output_shape;
    // Per input: shape of that input tensor.
    std::vector<std::vector<size_t>> input_shapes;
    // Strides of the output tensor.
    std::vector<ptrdiff_t> output_strides;
    // Per input: strides of that input tensor.
    std::vector<std::vector<ptrdiff_t>> input_strides;
};

// Fills `info` with the shape, stride and layout metadata of an elementwise
// operation.
//
// Parameters:
//   info        - destination; modified only on success. Previous content of
//                 the per-input vectors is discarded, so the same object can
//                 be reused across calls.
//   output_desc - descriptor of the output tensor; must be non-null and must
//                 not contain a broadcast dimension.
//   input_descs - descriptors of the inputs; must be non-empty and must not
//                 contain null entries.
//
// Returns:
//   INFINI_STATUS_BAD_PARAM          - null output, no inputs, or a null input.
//   INFINI_STATUS_BAD_TENSOR_STRIDES - the output has a broadcast dimension.
//   INFINI_STATUS_SUCCESS            - `info` has been populated.
inline infiniStatus_t createElementwiseInfo(
    ElementwiseInfo &info,
    infiniopTensorDescriptor_t output_desc,
    std::vector<infiniopTensorDescriptor_t> input_descs) {

    if (!output_desc || input_descs.empty()) {
        return INFINI_STATUS_BAD_PARAM;
    }

    // Reject null input descriptors up front instead of dereferencing them below.
    for (const auto &desc : input_descs) {
        if (!desc) {
            return INFINI_STATUS_BAD_PARAM;
        }
    }

    // The destination cannot have a broadcast setup.
    if (output_desc->hasBroadcastDim()) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    const size_t input_count = input_descs.size();
    const size_t out_ndim = output_desc->ndim();

    // Output metadata. The shape/strides are copied: the descriptor stays intact.
    info.output_size = output_desc->numel();
    info.ndim = out_ndim;
    info.output_contiguous = output_desc->isContiguous();
    info.output_shape = output_desc->shape();
    info.output_strides = output_desc->strides();

    // Drop anything left over from a previous call before appending.
    info.input_contiguous.clear();
    info.input_broadcasted.clear();
    info.input_shapes.clear();
    info.input_strides.clear();

    info.input_contiguous.reserve(input_count);
    info.input_broadcasted.reserve(input_count);
    info.input_shapes.reserve(input_count);
    info.input_strides.reserve(input_count);

    for (const auto &desc : input_descs) {
        const bool contiguous = desc->isContiguous();
        info.input_contiguous.emplace_back(contiguous);
        // An input is flagged as broadcasted only when it is non-contiguous and
        // either its rank differs from the output's or it has a broadcast dim.
        info.input_broadcasted.emplace_back(
            !contiguous && (desc->ndim() != out_ndim || desc->hasBroadcastDim()));
        info.input_shapes.emplace_back(desc->shape());
        info.input_strides.emplace_back(desc->strides());
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::elementwise

#endif // __INFINIOP_ELEMENTWISE_H__