info.h 2.05 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef __SWIGLU_CUDA_INFO_H__
#define __SWIGLU_CUDA_INFO_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"

namespace op::swiglu_cuda {

/// Dtype, shape and stride metadata for the CUDA SwiGLU operator, describing
/// the output tensor `c` and the two inputs `a` and `b`. All three must share
/// one dtype and one shape.
///
/// Supported layouts are rank 2 `[seq_len, hidden_dim]` and rank 3
/// `[batch, seq_len, hidden_dim]`. For rank 2, `batch` is reported as 1 and the
/// `*_strides_0` (batch-axis) strides are 0. Only the batch and seq_len
/// strides are stored; the hidden_dim axis is required to be contiguous
/// (stride 1), which createSwiGLUCudaInfo() enforces.
class SwiGLUCudaInfo {
    // Only reachable through createSwiGLUCudaInfo(), which validates first.
    // NOTE(review): the `SwiGLUCudaInfo{...}` construction below relies on the
    // C++17 rule that a class whose only constructor is defaulted (not
    // user-provided) is still an aggregate. C++20 (P1008) removes that, so a
    // C++20 build would need an explicit constructor here.
    SwiGLUCudaInfo() = default;

public:
    infiniDtype_t dtype;               ///< element dtype shared by c, a and b
    size_t length;                     ///< total element count: batch * seq_len * hidden_dim
    size_t batch, seq_len, hidden_dim; ///< logical extents (batch == 1 for rank-2 input)
    ptrdiff_t c_strides_0, c_strides_1; ///< c strides along the batch / seq_len axes
    ptrdiff_t a_strides_0, a_strides_1; ///< a strides along the batch / seq_len axes
    ptrdiff_t b_strides_0, b_strides_1; ///< b strides along the batch / seq_len axes

    /// Validates the three descriptors and builds the metadata.
    ///
    /// @param c_desc output tensor descriptor
    /// @param a_desc first input tensor descriptor
    /// @param b_desc second input tensor descriptor
    /// @return the populated info on success; otherwise an error status:
    ///   - INFINI_STATUS_BAD_TENSOR_DTYPE if the three dtypes differ,
    ///   - whatever CHECK_DTYPE / CHECK_SAME_SHAPE report for an unsupported
    ///     dtype or mismatched shapes,
    ///   - INFINI_STATUS_BAD_TENSOR_SHAPE if the rank is not 2 or 3,
    ///   - INFINI_STATUS_BAD_TENSOR_STRIDES if the innermost (hidden_dim)
    ///     axis of any tensor is not contiguous.
    static utils::Result<SwiGLUCudaInfo> createSwiGLUCudaInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) {
        auto dtype = c_desc->dtype();
        if (dtype != a_desc->dtype() || dtype != b_desc->dtype()) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

        // Bind by const reference: no copy when the accessor returns a
        // reference, and still safe (lifetime extension) if it returns a value.
        const auto &shape = c_desc->shape();
        CHECK_SAME_SHAPE(shape, a_desc->shape(), b_desc->shape());

        // Only rank 2 and rank 3 can be described by this struct. Without this
        // check, rank < 2 indexes `shape` out of bounds, and rank > 3 silently
        // drops the leading dimensions and misreads the strides.
        const auto ndim = c_desc->ndim();
        if (ndim != 2 && ndim != 3) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        const bool has_batch = (ndim == 3);

        size_t hidden_dim = shape[ndim - 1];
        size_t seq_len = shape[ndim - 2];
        size_t batch = (has_batch ? shape[0] : 1);
        size_t length = batch * seq_len * hidden_dim;

        const auto &c_strides = c_desc->strides();
        const auto &a_strides = a_desc->strides();
        const auto &b_strides = b_desc->strides();

        // The innermost stride is not recorded, so consumers of this struct
        // can only address hidden_dim contiguously. Reject non-unit (including
        // broadcast 0) strides instead of silently mis-addressing. A size-1 axis
        // has no meaningful stride, so it is exempt.
        if (hidden_dim > 1
            && (c_strides[ndim - 1] != 1 || a_strides[ndim - 1] != 1 || b_strides[ndim - 1] != 1)) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        // For rank 2 there is no batch axis: its stride is reported as 0, and
        // the seq_len stride lives at index 0 rather than 1.
        const auto seq_axis = ndim - 2;
        ptrdiff_t c_strides_0 = (has_batch ? c_strides[0] : 0);
        ptrdiff_t c_strides_1 = c_strides[seq_axis];
        ptrdiff_t a_strides_0 = (has_batch ? a_strides[0] : 0);
        ptrdiff_t a_strides_1 = a_strides[seq_axis];
        ptrdiff_t b_strides_0 = (has_batch ? b_strides[0] : 0);
        ptrdiff_t b_strides_1 = b_strides[seq_axis];

        return utils::Result<SwiGLUCudaInfo>(SwiGLUCudaInfo{
            dtype,
            length,
            batch,
            seq_len,
            hidden_dim,
            c_strides_0,
            c_strides_1,
            a_strides_0,
            a_strides_1,
            b_strides_0,
            b_strides_1});
    }
};

} // namespace op::swiglu_cuda

#endif // __SWIGLU_CUDA_INFO_H__