"mmde/configs/_base_/models/spvcnn.py" did not exist on "7aa442d5f635889648e8570b77af46eb088c78d6"
jiuge.h 3.25 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#ifndef MODEL_JIUGE_H
#define MODEL_JIUGE_H

#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>

#include <stdint.h>

/// Opaque model handle. Only forward-declared in this header; the functions
/// declared here pass it around by pointer.
struct JiugeModel;

/// Scalar configuration of a Jiuge model: data types, size parameters and a
/// few constants. Passed by const pointer to createJiugeModel.
/// NOTE(review): this header does not document the individual fields; the
/// meanings given below are inferred from the field names and from the shape
/// comments on JiugeWeights, and should be confirmed against the implementation.
typedef struct
{
    // Data types; presumably of the logits, the norm weights and the matrix
    // weights respectively - TODO confirm.
    infiniDtype_t dt_logits, dt_norm, dt_mat;
    // Size parameters. d, nh, nkvh, dh, di and dvoc are the symbols used in the
    // JiugeWeights shape comments. Presumably: nlayer = layer count,
    // d = hidden size, nh = attention heads, nkvh = key/value heads,
    // dh = per-head size, di = FFN intermediate size, dctx = maximum context
    // length, dvoc = vocabulary size - TODO confirm.
    size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc;
    // Presumably the normalization epsilon and the rotary-embedding base - TODO confirm.
    float epsilon, theta;
    // Presumably the id of the token that ends generation - TODO confirm.
    uint32_t end_token;
} JiugeMeta;

/// Read-only pointers to the model weights. Passed by const pointer to
/// createJiugeModel.
///
/// The bracketed comments give the shape of each tensor using the size symbols
/// of JiugeMeta. Fields typed `const void *const *` are annotated
/// "nlayer * [shape]", i.e. presumably an array of nlayer pointers, one tensor
/// per layer - TODO confirm. `ndev` is not a field of either struct; it is
/// presumably the device count given to createJiugeModel - TODO confirm.
/// NOTE(review): element types and whether the pointers refer to host or device
/// memory are not stated in this header.
typedef struct
{
    // Number of layers (length of the per-layer pointer arrays, per their
    // "nlayer * [...]" annotations).
    size_t nlayer;
    // Input embedding table: [dvoc, d]
    const void *input_embd;
    // Output norm weight: [d]
    const void *output_norm;
    // Output embedding: [dvoc, d]
    const void *output_embd;
    // Attention norm weight: nlayer * [d]
    const void *const *attn_norm;
    // Attention QKV weight: nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d]
    const void *const *attn_qkv;
    // Attention QKV bias: nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
    const void *const *attn_qkv_b;
    // Attention output weight: nlayer * [ndev, d, nkvh / ndev * dh]
    // NOTE(review): the last dimension uses nkvh; if nh != nkvh one would
    // expect nh / ndev * dh here - confirm whether the comment is accurate.
    const void *const *attn_o;
    // FFN norm weight: nlayer * [d]
    const void *const *ffn_norm;
    // FFN gate/up weight: nlayer * [ndev, 2 * di / ndev, d]
    const void *const *ffn_gate_up;
    // FFN down weight: nlayer * [ndev, d, di / ndev]
    const void *const *ffn_down;
} JiugeWeights;

//////////////////// APIs ///////////////////////
/// @brief Create a model.
///
/// The first two parameters are unnamed in this prototype: a pointer to the
/// model's JiugeMeta and a pointer to its JiugeWeights, in that order.
/// @param device Coprocessor (device) type
/// @param ndev Number of coprocessors
/// @param dev_ids Coprocessor IDs; array of length ndev
/// @return Pointer to the created model.
///         NOTE(review): presumably released with destroyJiugeModel; failure
///         behaviour is not documented here - confirm.
__C __export struct JiugeModel *
createJiugeModel(const JiugeMeta *,
                 const JiugeWeights *,
                 infiniDevice_t device,
                 int ndev,
                 const int *dev_ids);

/// @brief Destroy a model.
///
/// The single (unnamed) parameter is the model to destroy.
__C __export void
destroyJiugeModel(struct JiugeModel *);

/// @brief Create a KV cache.
///
/// The single (unnamed) parameter is the model the cache is created for.
/// @return Pointer to the new KV cache.
///         NOTE(review): presumably released with dropKVCache - confirm.
__C __export struct KVCache *
createKVCache(const struct JiugeModel *);

/// @brief Copy a KV cache.
///
/// The first two parameters are unnamed in this prototype: the model, then the
/// KV cache to copy from.
/// @param seq_len Not documented in the original header.
///        NOTE(review): presumably the number of cached positions to copy - confirm.
/// @return Pointer to a KV cache (the copy).
__C __export struct KVCache *
duplicateKVCache(const struct JiugeModel *,
                 const struct KVCache *, uint32_t seq_len);

/// @brief Destroy a KV cache.
///
/// Both parameters are unnamed in this prototype: the model, then the KV cache
/// to destroy.
__C __export void
dropKVCache(const struct JiugeModel *,
            struct KVCache *);

/// @brief Text generation.
///
/// The first two parameters are unnamed in this prototype: the model, then the
/// KV cache to use.
/// @param tokens Input tokens
/// @param ntok Number of input tokens
/// @param req_pos Start position of the request (a single value in this function)
/// @param output Address the output tokens are written to
/// @param max_step Maximum number of output tokens
/// @param temperature Sampling temperature (0. means greedy sampling)
/// @param topk Sampling top-k (1 means greedy sampling)
/// @param topp Sampling top-p
__C __export void
generate(struct JiugeModel *,
         struct KVCache *,
         const uint32_t *tokens, uint32_t ntok, uint32_t req_pos,
         uint32_t *output, uint32_t max_step,
         float temperature, uint32_t topk, float topp);

/// @brief Run one round of batched inference.
///
/// The first parameter is unnamed in this prototype: the model.
/// @param tokens Address of the input tokens
/// @param ntok Number of input tokens
/// @param req_lens Number of tokens of each request
/// @param nreq Number of requests
/// @param req_pos Start position of each request
/// @param kv_caches KV cache of each request
/// @param output Output token array, one output per request; length must be at least nreq
/// @param temperature Sampling temperature (0. means greedy sampling)
/// @param topk Sampling top-k (1 means greedy sampling)
/// @param topp Sampling top-p
__C __export void
inferBatch(struct JiugeModel *,
           const uint32_t *tokens, uint32_t ntok,
           const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
           struct KVCache **kv_caches,
           uint32_t *output,
           float temperature, uint32_t topk, float topp);

#endif