// // Created by siemon on 8/9/23. // #ifndef TEST_QWEN_H #define TEST_QWEN_H #include "basellm.h" namespace fastllm { class QWenModel : public basellm { public: QWenModel(); // 推理 virtual int Forward( const Data &inputIds, const Data &attentionMask, const Data &positionIds, std::vector > &pastKeyValues, const GenerationConfig &generationConfig = GenerationConfig(), const LastTokensManager &lastTokens = LastTokensManager(), std::vector *logits = nullptr); std::vector ForwardBatch( int batch, const Data &inputIds, const Data &attentionMask, const Data &positionIds, std::vector > &pastKeyValues, const GenerationConfig &generationConfig = GenerationConfig(), const LastTokensManager &lastTokens = LastTokensManager(), std::vector *> *logits = nullptr); std::vector ForwardBatch( int batch, const Data &inputIds, const std::vector &attentionMask, const std::vector &positionIds, const std::vector &seqLens, std::vector > &pastKeyValues, const std::vector &generationConfigs, const LastTokensManager &lastTokens = LastTokensManager(), std::vector *> *retLogits = nullptr); virtual std::string MakeInput(const std::string &history, int round, const std::string &input); virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); virtual void FillLLMInputs(std::vector > &inputTokens, const std::map ¶ms, Data &inputIds, Data &attentionMask, Data &positionIds); virtual void FillLLMInputsBatch(std::vector > &inputTokens, const std::vector > ¶ms, Data &inputIds, Data &attentionMask, Data &positionIds); virtual void WarmUp(); void UpdateRotaryPosEmb(float ntk_alpha); int seq_length; float ntk_alpha; bool use_log_attn; Data logn_list; private: std::string im_start = "<|im_start|>"; std::string im_end = "<|im_end|>"; }; } #endif //TEST_QWEN_H