#ifndef MODEL_JIUGE_H #define MODEL_JIUGE_H #include #include #include #include struct JiugeModel; typedef struct { infiniDtype_t dt_logits; size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; float epsilon, theta; uint32_t end_token; } JiugeMeta; typedef struct { size_t nlayer; infiniDtype_t dt_norm, dt_mat; // [dvoc, d] const void *input_embd; // [d] const void *output_norm; // [dvoc, d] const void *output_embd; // nlayer * [d] const void *const *attn_norm; // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d] const void *const *attn_qkv; // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh] const void *const *attn_qkv_b; // nlayer * [ndev, d, nkvh / ndev * dh] const void *const *attn_o; // nlayer * [d] const void *const *ffn_norm; // nlayer * [ndev, 2 * di / ndev, d] const void *const *ffn_gate_up; // nlayer * [ndev, d, di / ndev] const void *const *ffn_down; } JiugeWeights; //////////////////// APIs /////////////////////// /// @brief 创建模型 /// @param device 协处理器种类 /// @param ndev 协处理器数量 /// @param dev_ids 协处理器编号,长度为 ndev __C __export struct JiugeModel * createJiugeModel(const JiugeMeta *, const JiugeWeights *, infiniDevice_t device, int ndev, const int *dev_ids); /// @brief 销毁模型 __C __export void destroyJiugeModel(struct JiugeModel *); /// @brief 创建 KV Cache __C __export struct KVCache * createKVCache(const struct JiugeModel *); /// @brief 复制 KV Cache __C __export struct KVCache * duplicateKVCache(const struct JiugeModel *, const struct KVCache *, uint32_t seq_len); /// @brief 销毁 KV Cache __C __export void dropKVCache(const struct JiugeModel *, struct KVCache *); /// @brief 文本生成 /// @param tokens 输入 token /// @param ntok 输入 token 数量 /// @param req_pos 每个请求的起始位置 /// @param output 输出 token 地址 /// @param max_step 输出 token 最大数量 /// @param temperature 采样温度(0. 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp __C __export void generate(struct JiugeModel *, struct KVCache *, const uint32_t *tokens, uint32_t ntok, uint32_t req_pos, uint32_t *output, uint32_t max_step, float temperature, uint32_t topk, float topp); /// @brief 批次推理一轮 /// @param tokens 输入 token 地址 /// @param ntok 输入 token 数量 /// @param nreq 请求数量 /// @param req_lens 每个请求的 token 数量 /// @param req_pos 每个请求的起始位置 /// @param kv_caches 每个请求的 KV Cache /// @param ans 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param temperature 采样温度(0. 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp __C __export void inferBatch(struct JiugeModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, uint32_t *output, float temperature, uint32_t topk, float topp); #endif