/*
 * jiuge.h - public C API of the Jiuge model.
 * Source listing: jiuge.h (3.31 KB); initial commit "init" by PanZezhong.
 */
#ifndef MODEL_JIUGE_H
#define MODEL_JIUGE_H

#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>

#include <stddef.h>
#include <stdint.h>

// Opaque model handle: returned by createJiugeModel() and passed to
// destroyJiugeModel(), inferBatchJiuge() and forwardBatchJiuge().
struct JiugeModel;

// Opaque KV-cache handle used by inferBatchJiuge()/forwardBatchJiuge().
// It must be declared at file scope here. Otherwise its first mention inside
// those parameter lists gets prototype scope. That creates a distinct struct
// type, incompatible with the `struct KVCache` callers actually pass, and
// triggers a "declared inside parameter list" diagnostic. Redeclaring an
// incomplete struct is harmless if another header also declares it.
struct KVCache;

typedef struct
{
PanZezhong's avatar
PanZezhong committed
14
    infiniDtype_t dt_logits;
PanZezhong's avatar
init  
PanZezhong committed
15
16
17
18
19
20
21
22
    size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc;
    float epsilon, theta;
    uint32_t end_token;
} JiugeMeta;

typedef struct
{
    size_t nlayer;
PanZezhong's avatar
PanZezhong committed
23
    infiniDtype_t dt_norm, dt_mat;
PanZezhong's avatar
PanZezhong committed
24
25
    // 0 if linear weights are passed as W, any other value if passed as W^T (default format in pytorch)
    int transpose_linear_weights;
PanZezhong's avatar
init  
PanZezhong committed
26
27
28
29
30
31
32
33
34
35
36
37
    // [dvoc, d]
    const void *input_embd;
    // [d]
    const void *output_norm;
    // [dvoc, d]
    const void *output_embd;
    // nlayer * [d]
    const void *const *attn_norm;
    // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d]
    const void *const *attn_qkv;
    // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
    const void *const *attn_qkv_b;
38
39
40
41
    // nlayer * [dh]
    const void *const *attn_q_norm;
    // nlayer * [dh]
    const void *const *attn_k_norm;
PanZezhong's avatar
init  
PanZezhong committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
    // nlayer * [ndev, d, nkvh / ndev * dh]
    const void *const *attn_o;
    // nlayer * [d]
    const void *const *ffn_norm;
    // nlayer * [ndev, 2 * di / ndev, d]
    const void *const *ffn_gate_up;
    // nlayer * [ndev, d, di / ndev]
    const void *const *ffn_down;
} JiugeWeights;

//////////////////// APIs ///////////////////////
/// @brief 创建模型
/// @param device 协处理器种类
/// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev
57
__INFINI_C __export struct JiugeModel *
PanZezhong's avatar
init  
PanZezhong committed
58
59
60
61
62
63
64
createJiugeModel(const JiugeMeta *,
                 const JiugeWeights *,
                 infiniDevice_t device,
                 int ndev,
                 const int *dev_ids);

/// @brief 销毁模型
65
__INFINI_C __export void
PanZezhong's avatar
init  
PanZezhong committed
66
67
destroyJiugeModel(struct JiugeModel *);

PanZezhong's avatar
PanZezhong committed
68
/// @brief 批次推理一轮,并采样出新的 token
PanZezhong's avatar
init  
PanZezhong committed
69
70
71
72
73
74
75
76
77
/// @param tokens 输入 token 地址
/// @param ntok 输入 token 数量
/// @param nreq 请求数量
/// @param req_lens 每个请求的 token 数量
/// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache
/// @param temperature 采样温度(0. 表示贪心采样)
/// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp
Pan Zezhong's avatar
Pan Zezhong committed
78
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
79
__INFINI_C __export void
blkmjsian's avatar
blkmjsian committed
80
81
82
83
84
85
inferBatchJiuge(struct JiugeModel *,
                const uint32_t *tokens, uint32_t ntok,
                const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                struct KVCache **kv_caches,
                const float *temperature, const uint32_t *topk, const float *topp,
                uint32_t *output);
PanZezhong's avatar
init  
PanZezhong committed
86

PanZezhong's avatar
PanZezhong committed
87
88
89
90
91
92
93
94
/// @brief 批次推理一轮,输出 output embedding 后的 logits
/// @param tokens 输入 token 地址
/// @param ntok 输入 token 数量
/// @param nreq 请求数量
/// @param req_lens 每个请求的 token 数量
/// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
95
__INFINI_C __export void
blkmjsian's avatar
blkmjsian committed
96
97
98
99
100
forwardBatchJiuge(struct JiugeModel *,
                  const uint32_t *tokens, uint32_t ntok,
                  const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                  struct KVCache **kv_caches,
                  void *logits);
PanZezhong's avatar
PanZezhong committed
101

PanZezhong's avatar
init  
PanZezhong committed
102
#endif