#ifndef __INFINIOP_BINARY_H__
#define __INFINIOP_BINARY_H__

#include "../operator.h"
#include "../tensor.h"

#include <algorithm>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

/**
 * BINARY_DESCRIPTOR(OP, NAMESPACE)
 *
 * The design of this class is based on the DESCRIPTOR macro written by
 * YdrMaster in matmul.h.
 *
 * Declares the descriptor class `op::OP::NAMESPACE::Descriptor` for a binary
 * operator with output tensor `c` and input tensors `a` and `b`. `OP` names the
 * operator's namespace; `NAMESPACE` is presumably the backend/device namespace
 * (e.g. `cpu`) — confirm against the expansion sites. Expand the macro at
 * global namespace scope so the class lands in `::op::OP::NAMESPACE`.
 *
 * The generated class:
 *   - is `final` and derives from InfiniopDescriptor, which it initializes
 *     with the device type and device id passed to its constructor;
 *   - stores an `Opaque *` (backend-specific state; `struct Opaque` is only
 *     forward-declared here and must be defined by each implementation), the
 *     tensor dtype, and an `op::binary::BinaryInfo` with shapes/strides;
 *   - has a private constructor, so instances can only be built from inside
 *     the class, e.g. by the static `create` factory;
 *   - only declares `~Descriptor()`, `create()` and `calculate()`; their
 *     definitions must be provided wherever the macro is used.
 *
 * `create()` receives the handle, an out-pointer for the new descriptor and
 * the tensor descriptors of c (output), a and b (inputs). `calculate()`
 * receives the data pointers of c, a and b plus an opaque `stream` pointer.
 *
 * NOTE(review): if ~Descriptor() releases `_opaque`, copying a Descriptor
 * would double-release it; consider deleting the copy operations (Rule of
 * Five) — confirm against the destructor definitions.
 */

#define BINARY_DESCRIPTOR(OP, NAMESPACE)                  \
                                                          \
    namespace op::OP::NAMESPACE {                         \
    class Descriptor final : public InfiniopDescriptor {  \
        struct Opaque;                                    \
        Opaque *_opaque;                                  \
        infiniDtype_t _dtype;                             \
        op::binary::BinaryInfo _info;                     \
                                                          \
        Descriptor(                                       \
            infiniDtype_t dtype,                          \
            op::binary::BinaryInfo info,                  \
            Opaque *opaque,                               \
            infiniDevice_t device_type,                   \
            int device_id)                                \
            : InfiniopDescriptor{device_type, device_id}, \
              _opaque(opaque),                            \
              _dtype(dtype),                              \
              _info(info) {}                              \
                                                          \
    public:                                               \
        ~Descriptor();                                    \
                                                          \
        static infiniStatus_t create(                     \
            infiniopHandle_t handle,                      \
            Descriptor **desc_ptr,                        \
            infiniopTensorDescriptor_t c_desc,            \
            infiniopTensorDescriptor_t a_desc,            \
            infiniopTensorDescriptor_t b_desc);           \
                                                          \
        infiniStatus_t calculate(                         \
            void *c,                                      \
            const void *a,                                \
            const void *b,                                \
            void *stream) const;                          \
    };                                                    \
    }

namespace op::binary {

// Stores metadata for a binary operation c = op(a, b); the original author
// described it as CPU-side metadata. Populated by createBinaryInfo().
//
// Scalar members carry default initializers so a BinaryInfo never holds
// indeterminate values, e.g. when createBinaryInfo() returns an error before
// filling it in. Default member initializers keep the struct an aggregate.
struct BinaryInfo {
    // Total number of elements in c (product of c_shape).
    size_t c_data_size = 0;
    // Number of dimensions of c.
    size_t ndim = 0;
    // Whether c, a and b can be traversed as contiguous buffers.
    bool contiguous = false;
    // Whether a and/or b have to be broadcast against c.
    bool broadcasted = false;
    // Shapes of c, a and b as reported by their tensor descriptors.
    std::vector<size_t> c_shape;
    std::vector<size_t> a_shape;
    std::vector<size_t> b_shape;
    // Strides of c, a and b as reported by their tensor descriptors
    // (same units as the descriptors use — not fixed here).
    std::vector<ptrdiff_t> c_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;
};
/**
 * Fills `info` with the shape/stride metadata for a binary operation
 * c = op(a, b).
 *
 * @param info    Output. Only written when the function succeeds; on error it
 *                is left untouched.
 * @param c_desc  Tensor descriptor of the output c.
 * @param a_desc  Tensor descriptor of the first input a.
 * @param b_desc  Tensor descriptor of the second input b.
 * @return INFINI_STATUS_SUCCESS on success;
 *         INFINI_STATUS_BAD_PARAM if any descriptor is null;
 *         INFINI_STATUS_BAD_TENSOR_STRIDES if c has stride 0 on a dimension
 *         of extent > 1 (a broadcast destination is not allowed).
 */
inline infiniStatus_t createBinaryInfo(BinaryInfo &info,
                                       infiniopTensorDescriptor_t c_desc,
                                       infiniopTensorDescriptor_t a_desc,
                                       infiniopTensorDescriptor_t b_desc) {
    if (!c_desc || !a_desc || !b_desc) {
        return INFINI_STATUS_BAD_PARAM;
    }

    const auto &c_shape = c_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();
    const auto &c_strides = c_desc->strides();
    const auto &a_strides = a_desc->strides();
    const auto &b_strides = b_desc->strides();

    // A tensor is broadcast when some dimension of extent > 1 has stride 0,
    // i.e. the same element is reused along that dimension. Every dimension
    // is checked with an explicit index; the index must advance on every
    // step, including size-1 dimensions, or later dimensions are skipped.
    auto isBroadcasted = [](const std::vector<size_t> &shape, const std::vector<ptrdiff_t> &strides) {
        const size_t n = std::min(shape.size(), strides.size());
        for (size_t i = 0; i < n; ++i) {
            if (shape[i] != 1 && strides[i] == 0) {
                return true;
            }
        }
        return false;
    };

    // Destination cannot have broadcast setup.
    if (isBroadcasted(c_shape, c_strides)) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    // Contiguity alone is not enough for a flat element-by-element traversal:
    // an input whose shape differs from c's (different rank, or size-1
    // dimensions expanded against c) has fewer elements than c and must be
    // broadcast, even if its own buffer is contiguous.
    const bool same_shape = (a_shape == c_shape) && (b_shape == c_shape);
    const bool contiguous = same_shape
                         && c_desc->isContiguous()
                         && a_desc->isContiguous()
                         && b_desc->isContiguous();
    const bool broadcasted = !contiguous
                          && (!same_shape
                              || isBroadcasted(a_shape, a_strides)
                              || isBroadcasted(b_shape, b_strides));

    // All validation is done; publish the results.
    info.c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), size_t(1), std::multiplies<size_t>());
    info.ndim = c_desc->ndim();
    info.contiguous = contiguous;
    info.broadcasted = broadcasted;

    // The descriptors' shapes/strides are bound by const reference, so they
    // are copied (std::move on a const lvalue would copy anyway).
    info.c_shape = c_shape;
    info.a_shape = a_shape;
    info.b_shape = b_shape;
    info.c_strides = c_strides;
    info.a_strides = a_strides;
    info.b_strides = b_strides;

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::binary

#endif // __INFINIOP_BINARY_H__