cdist.h 1.69 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#ifndef __INFINIOP_CDIST_API_H__
#define __INFINIOP_CDIST_API_H__

#include "../operator_descriptor.h"

// 定义 cdist 算子描述符类型
typedef struct InfiniopDescriptor *infiniopCdistDescriptor_t;

/**
 * @brief 创建 Cdist 算子描述符
 * @param handle 算子句柄
 * @param desc_ptr 指向返回的描述符指针
 * @param y 输出张量描述符 (Shape: M x N)
 * @param x1 输入张量1描述符 (Shape: M x D)
 * @param x2 输入张量2描述符 (Shape: N x D)
 * @param p 范数阶数 (L-p norm)
 */
__INFINI_C __export infiniStatus_t infiniopCreateCdistDescriptor(
    infiniopHandle_t handle,
    infiniopCdistDescriptor_t *desc_ptr, // 注意这里应该是具体类型的指针
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x1_desc,
    infiniopTensorDescriptor_t x2_desc,
    double p);

/**
 * @brief 获取 Cdist 计算所需的临时空间大小
 */
__INFINI_C __export infiniStatus_t infiniopGetCdistWorkspaceSize(infiniopCdistDescriptor_t desc,
                                                                 size_t *size);

/**
 * @brief 执行 Cdist 计算
 * @param desc 算子描述符
 * @param workspace 临时空间指针
 * @param workspace_size 临时空间大小
 * @param y 输出数据指针
 * @param x1 输入1数据指针
 * @param x2 输入2数据指针
 * @param stream 计算流 (CUDA stream 等)
 */
__INFINI_C __export infiniStatus_t infiniopCdist(
    infiniopCdistDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x1,
    const void *x2,
    void *stream);

/**
 * @brief 销毁 Cdist 算子描述符
 */
__INFINI_C __export infiniStatus_t infiniopDestroyCdistDescriptor(infiniopCdistDescriptor_t desc);

#endif