embedding.h 2.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#ifndef __EMBEDDING_H__
#define __EMBEDDING_H__

#include "../../../utils.h"
#include "../../operator.h"

// DESCRIPTOR(NAMESPACE) declares the class
// `op::embedding::NAMESPACE::Descriptor` for the embedding operator.
// The macro only *declares* the class: the backend that expands it must
// itself define `Descriptor::Opaque`, `~Descriptor()`, `create()` and
// `calculate()`.
//
// State cached by the descriptor (set only through the private
// constructor, so only members such as `create()` can build one):
//   _opaque          backend-private state, held as a raw pointer.
//                    NOTE(review): presumably released by the backend's
//                    destructor -- confirm in each backend.
//   _num_indices     presumably the number of indices looked up per call.
//   _embedding_dim   presumably the length of one embedding row.
//   _vocab_size      presumably the number of rows in the weight table.
//   _input_dtype     dtype recorded for the input (index) tensor.
//   _weight_dtype    dtype recorded for the weight tensor.
//
// create():    static factory taking output/input/weight tensor
//              descriptors; presumably stores a new Descriptor in
//              *desc_ptr on success (TODO confirm in backends).
// calculate(): executes the operator on `stream`, writing to `output`
//              (presumably gathering rows of `weight` selected by
//              `input`; the stream type is backend-specific, hence void*).
//
// Copy construction/assignment are deleted: a copy would share the same
// raw `_opaque` pointer, and two destructors could then release it twice.
//
// Comments inside the macro body must use /* */ -- a // comment would
// swallow the trailing line-continuation backslash.
#define DESCRIPTOR(NAMESPACE)                                         \
                                                                      \
    namespace op::embedding::NAMESPACE {                              \
    class Descriptor final : public InfiniopDescriptor {              \
        struct Opaque;                                                \
        Opaque *_opaque;                                              \
        size_t _num_indices;                                          \
        size_t _embedding_dim;                                        \
        size_t _vocab_size;                                           \
        infiniDtype_t _input_dtype;                                   \
        infiniDtype_t _weight_dtype;                                  \
                                                                      \
        Descriptor(                                                   \
            size_t num_indices,                                       \
            size_t embedding_dim,                                     \
            size_t vocab_size,                                        \
            infiniDtype_t input_dtype,                                \
            infiniDtype_t weight_dtype,                               \
            Opaque *opaque,                                           \
            infiniDevice_t device_type,                               \
            int device_id)                                            \
            : InfiniopDescriptor{device_type, device_id},             \
              _opaque(opaque),                                        \
              _num_indices(num_indices),                              \
              _embedding_dim(embedding_dim),                          \
              _vocab_size(vocab_size),                                \
              _input_dtype(input_dtype),                              \
              _weight_dtype(weight_dtype) {}                          \
                                                                      \
    public:                                                           \
        ~Descriptor();                                                \
                                                                      \
        /* Non-copyable: a copy would alias _opaque. */               \
        Descriptor(const Descriptor &) = delete;                      \
        Descriptor &operator=(const Descriptor &) = delete;           \
                                                                      \
        static infiniStatus_t create(                                 \
            infiniopHandle_t handle,                                  \
            Descriptor **desc_ptr,                                    \
            infiniopTensorDescriptor_t output_desc,                   \
            infiniopTensorDescriptor_t input_desc,                    \
            infiniopTensorDescriptor_t weight_desc);                  \
                                                                      \
        infiniStatus_t calculate(                                     \
            void *output,                                             \
            const void *input,                                        \
            const void *weight,                                       \
            void *stream) const;                                      \
    };                                                                \
    }

#endif // __EMBEDDING_H__