cdist.h 3.6 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef __CDIST_H__
#define __CDIST_H__

#include "../../operator.h"
#include "info.h"

/**
 * # 关于 `cdist` 算子描述符的说明
 * * 仿照 GEMM 的 PImpl (Opaque) 设计模式,将硬件相关的执行上下文(如 CUDA Handle、计算流等)
 * 隐藏在 `Opaque` 结构体中,确保头文件在不同后端(CPU/NVIDIA/Ascend)间的一致性。
 */

#define DESCRIPTOR(NAMESPACE)                                       \
                                                                    \
    namespace op::cdist::NAMESPACE {                                \
    class Descriptor final : public InfiniopDescriptor {            \
        struct Opaque;                                              \
        Opaque *_opaque;                                            \
        infiniDtype_t _dtype;                                       \
        CdistInfo _info; /* 包含 M, N, D 维度信息 */          \
        size_t _workspace_size;                                     \
        double _p; /* 范数阶数,创建时固定 */             \
                                                                    \
        Descriptor(                                                 \
            infiniDtype_t dtype,                                    \
            CdistInfo info,                                         \
            double p,                                               \
            size_t workspace_size_,                                 \
            Opaque *opaque,                                         \
            infiniDevice_t device_type,                             \
            int device_id)                                          \
            : InfiniopDescriptor{device_type, device_id},           \
              _opaque(opaque),                                      \
              _dtype(dtype),                                        \
              _info(info),                                          \
              _workspace_size(workspace_size_),                     \
              _p(p) {}                                              \
                                                                    \
    public:                                                         \
        ~Descriptor();                                              \
                                                                    \
        size_t workspaceSize() const { return _workspace_size; }    \
                                                                    \
        static infiniStatus_t create(                               \
            infiniopHandle_t handle,                                \
            Descriptor **desc_ptr,                                  \
            infiniopTensorDescriptor_t y_desc,  /* 输出 (M, N) */ \
            infiniopTensorDescriptor_t x1_desc, /* 输入 (M, D) */ \
            infiniopTensorDescriptor_t x2_desc, /* 输入 (N, D) */ \
            double p);                                              \
                                                                    \
        infiniStatus_t calculate(                                   \
            void *workspace,                                        \
            size_t workspace_size,                                  \
            void *y, /* 结果矩阵 */                             \
            const void *x1,                                         \
            const void *x2,                                         \
            void *stream) const;                                    \
    };                                                              \
    }

#endif // __CDIST_H__