#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <map>
#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif
#include <chrono>
#include "chatglm.h"
#include "chatglm_c.h"

//static int multi_round_flag = 0;

static double GetSpan(std::chrono::system_clock::time_point time1, std::chrono::system_clock::time_point time2) {
    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1);
    return double(duration.count()) * std::chrono::microseconds::period::num / std::chrono::microseconds::period::den;
}

std::map<std::string, int> modelDict = {
    {"chatglm", 0}, {"chatglm-c", 1}
};

struct RunConfig {
    int model = 0;                             // model type: 0 = chatglm C++ API, 1 = chatglm C API
    std::string path = "chatglm-6b-int8.bin";  // path to the model file
    int device = 0;                            // DCU device index used for inference
    int multi_round_flag = 0;                  // 0 = single-round chat, 1 = multi-round chat
    // int threads = 4;          // number of threads to use
    // bool lowMemMode = false;  // whether to use low-memory mode
};

void Usage() {
    std::cout << "Usage:" << std::endl;
    std::cout << "[-h|--help]:        show this help" << std::endl;
    std::cout << "<-p|--path>:        path to the model file" << std::endl;
    std::cout << "<-m|--model>:       API type, default 0; 0 = chatglm C++ API, 1 = chatglm C API" << std::endl;
    std::cout << "<-d|--device>:      index of the DCU device used for inference" << std::endl;
    std::cout << "<-r|--multi-round>: enable multi-round chat, default 0; 0 = single round, 1 = multi round" << std::endl;
}

void ParseArgs(int argc, char **argv, RunConfig &config) {
    std::vector<std::string> sargv;
    for (int i = 0; i < argc; i++) {
        sargv.push_back(std::string(argv[i]));
    }
    for (int i = 1; i < argc; i++) {
        if (sargv[i] == "-h" || sargv[i] == "--help") {
            Usage();
            exit(0);
        } else if (sargv[i] == "-m" || sargv[i] == "--model") {
            if (modelDict.find(sargv[i + 1]) != modelDict.end()) {
                config.model = modelDict[sargv[++i]];
            } else {
                config.model = atoi(sargv[++i].c_str());
            }
        } else if (sargv[i] == "-p" || sargv[i] == "--path") {
            config.path = sargv[++i];
        } else if (sargv[i] == "-d" || sargv[i] == "--device") {
            config.device = atoi(sargv[++i].c_str());
        } else if (sargv[i] == "-r" || sargv[i] == "--multi-round") {
            config.multi_round_flag = atoi(sargv[++i].c_str());
        } else {
            Usage();
            exit(-1);
        }
    }
}

bool fileExists(const std::string& filename) {
    std::ifstream file(filename);
    return file.good();
}

// Count UTF-8 characters (in UTF-8, an ASCII character occupies one byte and a Chinese character occupies three bytes).
static int getUtf8LetterNumber(const char *s) {
    int i = 0, j = 0;
    while (s[i]) {
        if ((s[i] & 0xc0) != 0x80) j++;
        i++;
    }
    return j;
}

int chat_history(fastllm::ChatGLMModel* chatGlm, const char* input_Str) {
    static int sRound = 0;
    static std::string history;
    static int tokens = 0;
    std::string ret = "";
    std::string input(input_Str);
    if (input == "reset") {
        history = "";
        sRound = 0;
        return 0;
    }
    // The "[Round N]\n问:...\n答:..." template is the prompt format expected by ChatGLM, so it is kept in Chinese.
    history += ("[Round " + std::to_string(sRound++) + "]\n问:" + input);
    if (getUtf8LetterNumber(history.c_str()) > 2048) {
        history = "";
        sRound = 0;
    }
    tokens = 0;
    auto prompt = sRound > 1 ? history : input;
    auto st = std::chrono::system_clock::now();
    ret = chatGlm->Response(prompt, [](int index, const char* content) {
        if (index == 0) {
            printf("ChatGLM:%s", content);
            tokens += 1;
        }
        if (index > 0) {
            printf("%s", content);
            tokens += 1;
        }
        if (index == -1) {
            printf("\n");
        }
    });
    float spend = GetSpan(st, std::chrono::system_clock::now());
    // character count of the reply
    int str_len = getUtf8LetterNumber(ret.c_str());
    printf("word_count = %d, token_count = %d, spend = %fs, word/s = %f, tokens/s = %f\n",
           str_len, tokens, spend, str_len / spend, tokens / spend);
    history += ("\n答:" + ret + "\n");
    return ret.length();
}

int main(int argc, char **argv) {
    RunConfig config;
    ParseArgs(argc, argv, config);
    if (!fileExists(config.path)) {
        printf("model path does not exist!\n");
        return -1;
    }
    if (config.model == 0) {
#ifdef USE_CUDA
        cudaSetDevice(config.device);
#endif
        fastllm::ChatGLMModel chatGlm;
        chatGlm.LoadFromFile(config.path);
        chatGlm.WarmUp();
        static int tokens = 0;
        while (true) {
            printf("User: ");
            std::string input;
            std::getline(std::cin, input);
            if (input == "stop") {
                break;
            }
            if (0 == config.multi_round_flag) {
                tokens = 0;
                auto st = std::chrono::system_clock::now();
                std::string ret = chatGlm.Response(input, [](int index, const char* content) {
                    if (index == 0) {
                        printf("ChatGLM:%s", content);
                        tokens += 1;
                    }
                    if (index > 0) {
                        printf("%s", content);
                        tokens += 1;
                    }
                    if (index == -1) {
                        printf("\n");
                    }
                });
                float spend = GetSpan(st, std::chrono::system_clock::now());
                // character count of the reply
                int str_len = getUtf8LetterNumber(ret.c_str());
                printf("word_count = %d, token_count = %d, spend = %fs, word/s = %f, tokens/s = %f\n",
                       str_len, tokens, spend, str_len / spend, tokens / spend);
            } else {
                chat_history(&chatGlm, input.c_str());
            }
        }
    } else if (config.model == 1) {
        void* modelEngine = NULL;
        initLLMEngine(modelEngine, config.path.c_str(), config.device);
        char *output = NULL;
        while (true) {
            printf("User: ");
            std::string input;
            std::getline(std::cin, input);
            //input = "晚上睡不着怎么办";
            if (input == "stop") {
                break;
            }
            chat(modelEngine, input.c_str(), output);
            printf("ChatGLM:%s\n", output);
        }
        releaseLLMEngine(modelEngine);
    } else {
        Usage();
        exit(-1);
    }
    return 0;
}
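/*
 * Example invocations (a sketch, not part of the original file): the binary name
 * `chatglm_demo` and the working-directory location of the model file are assumptions;
 * only the flags and their meanings come from ParseArgs() and RunConfig above.
 *
 *   // single-round chat through the C++ API on DCU device 0 (type "stop" to quit)
 *   ./chatglm_demo -p chatglm-6b-int8.bin -m 0 -d 0 -r 0
 *
 *   // multi-round chat; history is reset with "reset" or once it exceeds 2048 UTF-8 characters
 *   ./chatglm_demo -p chatglm-6b-int8.bin -m 0 -d 0 -r 1
 *
 *   // same model file through the C wrapper (initLLMEngine / chat / releaseLLMEngine)
 *   ./chatglm_demo -p chatglm-6b-int8.bin -m 1 -d 0
 */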