//
// Created by huangyuyang on 6/1/23.
//

#ifndef FASTLLM_LLAMA_H
#define FASTLLM_LLAMA_H

#include "basellm.h"

#include <cmath>
#include <iostream>

namespace fastllm {
    class LlamaModel : public basellm {
    public:
        LlamaModel(); // constructor

        // Inference: run one forward pass over the input tokens
        virtual int Forward(
                const Data &inputIds,
                const Data &attentionMask,
                const Data &positionIds,
                std::vector<std::pair<Data, Data>> &pastKeyValues,
                const GenerationConfig &generationConfig = GenerationConfig(),
                const LastTokensManager &lastTokens = LastTokensManager(),
                std::vector<float> *logits = nullptr);

        // Batched inference over `batch` sequences that share one attention mask
        std::vector<int> ForwardBatch(
                int batch,
                const Data &inputIds,
                const Data &attentionMask,
                const Data &positionIds,
                std::vector<std::pair<Data, Data>> &pastKeyValues,
                const GenerationConfig &generationConfig = GenerationConfig(),
                const LastTokensManager &lastTokens = LastTokensManager(),
                std::vector<std::vector<float> *> *logits = nullptr);

        // Batched inference with per-sequence masks, positions, lengths, and configs
        std::vector<int> ForwardBatch(
                int batch,
                const Data &inputIds,
                const std::vector<Data *> &attentionMask,
                const std::vector<Data *> &positionIds,
                const std::vector<int> &seqLens,
                std::vector<std::pair<Data *, Data *>> &pastKeyValues,
                const std::vector<GenerationConfig> &generationConfigs,
                const LastTokensManager &lastTokens = LastTokensManager(),
                std::vector<std::vector<float> *> *logits = nullptr);

        // Generate a reply for the given input
        virtual std::string Response(const std::string &input,
                                     RuntimeResult retCb,
                                     const GenerationConfig &generationConfig = GenerationConfig());

        // Generate replies for a batch of inputs
        virtual void ResponseBatch(const std::vector<std::string> &inputs,
                                   std::vector<std::string> &outputs,
                                   RuntimeResultBatch retCb,
                                   const GenerationConfig &generationConfig = GenerationConfig());

        // Launch a response task; returns the assigned handleId
        virtual int LaunchResponseTokens(const std::vector<int> &inputTokens,
                                         const GenerationConfig &generationConfig = GenerationConfig());

        // Fetch the next output token for the given handle; -1 means the output has finished
        virtual int FetchResponseTokens(int handleId);

        virtual void WarmUp(); // warm up

        // Build the prompt from the history and the current input
        virtual std::string MakeInput(const std::string &history, int round, const std::string &input);

        // Update the history with the current round's input and output
        virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);

        bool is_nsql = false;
    };
}

#endif // FASTLLM_LLAMA_H
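
// Usage sketch (illustrative, not part of this header): one plausible way to
// drive the streaming-token API declared above. Model loading and the default
// GenerationConfig below are assumptions about the surrounding fastllm code,
// not confirmed signatures; only LaunchResponseTokens/FetchResponseTokens are
// taken from the declarations in this file.
//
//     fastllm::LlamaModel model;
//     // ... load weights via the basellm loading path (assumed) ...
//     fastllm::GenerationConfig config;                  // default sampling settings
//     std::vector<int> inputTokens = { /* tokenized prompt */ };
//     int handleId = model.LaunchResponseTokens(inputTokens, config);
//     for (int token = model.FetchResponseTokens(handleId);
//          token != -1;                                  // -1 signals end of output
//          token = model.FetchResponseTokens(handleId)) {
//         // detokenize and emit `token` (detokenizer assumed elsewhere)
//     }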