// matmul.h
#ifndef __MATMUL_H__
#define __MATMUL_H__

#include "blas.h"
#include "infiniop/handle.h"
#include "infiniop/operator.h"

/**
 * Declares the matmul `Descriptor` class for one backend namespace.
 *
 * `DESCRIPTOR(NAMESPACE)` expands to `namespace matmul::NAMESPACE { ... }`
 * holding a `final` class `Descriptor` derived from `InfiniopDescriptor`.
 * The nested-namespace syntax requires C++17.
 *
 * Only declarations are generated here: `struct Opaque`, `~Descriptor()`,
 * `create()` and `calculate()` have to be defined elsewhere (presumably once
 * per backend -- confirm against the backend sources).
 *
 * Private part of the generated class:
 *  - `_opaque`: pointer to the forward-declared `Opaque`.
 *  - `_dtype`, `_info`: copies of the constructor arguments.
 *  - the constructor itself, which stores its arguments and passes
 *    `device_type` / `device_id` on to the `InfiniopDescriptor` base.
 *    Being private, it cannot be called from outside the class.
 *
 * Public part of the generated class:
 *  - `workspace_size`: the value given to the constructor.
 *  - `create(handle, desc_ptr, c_desc, a_desc, b_desc)`: static; presumably
 *    builds a descriptor and returns it through `desc_ptr` -- confirm.
 *  - `calculate(workspace, workspace_size, c, beta, a, b, alpha, stream)`:
 *    const member. NOTE(review): `alpha` / `beta` look like GEMM-style
 *    scaling factors (c = alpha * a * b + beta * c) -- confirm against the
 *    implementations. Its `workspace_size` parameter has the same name as
 *    the public data member, so a definition that keeps this name shadows
 *    the member.
 *
 * Maintenance note: every line of the macro body must end with a backslash,
 * and only block comments may be used inside it. Line splicing happens
 * before comments are stripped, so a `//` comment would swallow all the
 * following continuation lines.
 */
#define DESCRIPTOR(NAMESPACE)                             \
                                                          \
    namespace matmul::NAMESPACE {                         \
    class Descriptor final : public InfiniopDescriptor {  \
        struct Opaque;                                    \
        Opaque *_opaque;                                  \
        infiniDtype_t _dtype;                             \
        MatmulInfo _info;                                 \
                                                          \
        Descriptor(                                       \
            infiniDtype_t dtype,                          \
            MatmulInfo info,                              \
            size_t workspace_size_,                       \
            Opaque *opaque,                               \
            infiniDevice_t device_type,                   \
            int device_id)                                \
            : InfiniopDescriptor{device_type, device_id}, \
              _opaque(opaque),                            \
              _dtype(dtype),                              \
              _info(info),                                \
              workspace_size(workspace_size_) {}          \
                                                          \
    public:                                               \
        size_t workspace_size;                            \
                                                          \
        ~Descriptor();                                    \
                                                          \
        static infiniopStatus_t create(                   \
            infiniopHandle_t handle,                      \
            Descriptor **desc_ptr,                        \
            infiniopTensorDescriptor_t c_desc,            \
            infiniopTensorDescriptor_t a_desc,            \
            infiniopTensorDescriptor_t b_desc);           \
                                                          \
        infiniopStatus_t calculate(                       \
            void *workspace,                              \
            size_t workspace_size,                        \
            void *c,                                      \
            float beta,                                   \
            void const *a,                                \
            void const *b,                                \
            float alpha,                                  \
            void *stream) const;                          \
    };                                                    \
    }

#endif // __MATMUL_H__