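// rope.h
// Shared declarations for the RoPE operator: the DESCRIPTOR(NAMESPACE) macro
// that declares a per-backend Descriptor class, and RoPEInfo, which validates
// the tensor descriptors and records the shapes/strides used by the kernels.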
#ifndef __ROPE_H__
#define __ROPE_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "infiniop/ops/rope.h"

/**
 * Declares the class `op::rope::NAMESPACE::Descriptor` for one backend.
 *
 * The generated class derives from `InfiniopDescriptor` and stores:
 *   - `_opaque`:         pointer to a backend-defined `Opaque` struct, which is
 *                        only forward-declared here;
 *   - `_info`:           a `RoPEInfo` held by value;
 *   - `_workspace_size`: the value reported by `workspaceSize()`.
 *
 * The constructor is private, so instances are expected to be produced through
 * the static `create()` function. The destructor, `create()` and `calculate()`
 * are only declared by this macro; each backend supplies the definitions.
 *
 * Because `_info` is a by-value member, `RoPEInfo` must be a complete type at
 * the point where the macro is expanded (i.e. expand it after this header).
 *
 * NOTE: every line of the macro body except the last must end with a
 * backslash. Keep comments out of the body, or use block comments there: a
 * `//` comment on a continued line would swallow the lines that follow it.
 */
#define DESCRIPTOR(NAMESPACE)                                    \
                                                                 \
    namespace op::rope::NAMESPACE {                              \
    class Descriptor final : public InfiniopDescriptor {         \
        struct Opaque;                                           \
        Opaque *_opaque;                                         \
        RoPEInfo _info;                                          \
        size_t _workspace_size;                                  \
                                                                 \
        Descriptor(                                              \
            RoPEInfo info,                                       \
            size_t workspace_size_,                              \
            Opaque *opaque,                                      \
            infiniDevice_t device_type,                          \
            int device_id)                                       \
            : InfiniopDescriptor{device_type, device_id},        \
              _opaque(opaque),                                   \
              _info(info),                                       \
              _workspace_size(workspace_size_) {}                \
                                                                 \
    public:                                                      \
        ~Descriptor();                                           \
                                                                 \
        size_t workspaceSize() const { return _workspace_size; } \
                                                                 \
        static infiniStatus_t create(                            \
            infiniopHandle_t handle,                             \
            Descriptor **desc_ptr,                               \
            infiniopTensorDescriptor_t y_desc,                   \
            infiniopTensorDescriptor_t x_desc,                   \
            infiniopTensorDescriptor_t pos_desc,                 \
            infiniopTensorDescriptor_t sin_desc,                 \
            infiniopTensorDescriptor_t cos_desc,                 \
            infiniopRoPEAlgo_t algo);                            \
                                                                 \
        infiniStatus_t calculate(                                \
            void *workspace,                                     \
            size_t workspace_size,                               \
            void *y,                                             \
            const void *x,                                       \
            const void *pos_ids,                                 \
            const void *sin_table,                               \
            const void *cos_table,                               \
            void *stream) const;                                 \
    };                                                           \
    }

class RoPEInfo {
private:
    RoPEInfo() = default;

public:
    infiniDtype_t data_type, pos_type;
    size_t seqlen, nhead, dhead, table_len, table_dim;
    ptrdiff_t
        y_stride_seqlen,
        y_stride_nhead,
        x_stride_seqlen,
        x_stride_nhead;
68
    infiniopRoPEAlgo_t algo;
PanZezhong's avatar
PanZezhong committed
69

70
71
    static utils::Result<RoPEInfo>
    createRoPEInfo(
PanZezhong's avatar
PanZezhong committed
72
73
74
75
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t pos_desc,
        infiniopTensorDescriptor_t sin_desc,
76
77
        infiniopTensorDescriptor_t cos_desc,
        infiniopRoPEAlgo_t algo) {
PanZezhong's avatar
PanZezhong committed
78
        CHECK_OR_RETURN(
79
            y_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr && algo < infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_COUNT,
PanZezhong's avatar
PanZezhong committed
80
81
82
83
84
85
            INFINI_STATUS_NULL_POINTER);

        const infiniDtype_t data_type = y_desc->dtype();
        const infiniDtype_t pos_type = pos_desc->dtype();
        CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
                        INFINI_STATUS_BAD_TENSOR_DTYPE);
86
        CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
87
        CHECK_DTYPE_ANY_INT(pos_type);
PanZezhong's avatar
PanZezhong committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

        CHECK_OR_RETURN(y_desc->ndim() == 3
                            && x_desc->ndim() == 3
                            && pos_desc->ndim() == 1
                            && sin_desc->ndim() == 2
                            && cos_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        const auto seqlen = y_desc->dim(0),
                   nhead = y_desc->dim(1),
                   dhead = y_desc->dim(2),
                   table_len = sin_desc->dim(0),
                   table_dim = sin_desc->dim(1);

        CHECK_OR_RETURN(seqlen == x_desc->dim(0)
                            && seqlen == pos_desc->dim(0)
                            && nhead == x_desc->dim(1) && dhead == x_desc->dim(2)
                            && table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
        // Last dimension of x and y must be contiguous
        CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        // sin table and cos table must be totally contiguous
112
        CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
PanZezhong's avatar
PanZezhong committed
113
114
115
116
117
118
119
120
121
122
123
124
125

        return utils::Result<RoPEInfo>(RoPEInfo{
            data_type,
            pos_type,
            seqlen,
            nhead,
            dhead,
            table_len,
            table_dim,
            y_desc->stride(0),
            y_desc->stride(1),
            x_desc->stride(0),
            x_desc->stride(1),
126
            algo,
PanZezhong's avatar
PanZezhong committed
127
128
129
130
131
        });
    }
};

#endif