//
// Created by huangyuyang on 6/25/23.
//

#include "basellm.h"
#include "utils.h"

#include <sstream>
#include <chrono>

#ifdef USE_CUDA
#include "fastllm-cuda.cuh"
#endif

namespace fastllm {
    int ResponseContextDict::CreateHandle() {
        locker.lock();
        // Hand out the smallest id that is not currently in use.
        int newId = 0;
        while (dicts.find(newId) != dicts.end()) {
            newId++;
        }
        dicts[newId] = new ResponseContext();
        locker.unlock();
        return newId;
    }

    ResponseContext *ResponseContextDict::GetHandle(int handleId) {
        locker.lock();
        ResponseContext *ret = dicts.find(handleId) != dicts.end() ? dicts[handleId] : nullptr;
        locker.unlock();
        return ret;
    }

    void ResponseContextDict::RemoveHandle(int handleId) {
        locker.lock();
        if (dicts.find(handleId) != dicts.end()) {
            delete dicts[handleId];
            dicts.erase(handleId);
        }
        locker.unlock();
    }

    void ResponseContext::Init(int blocks) {
        pastKeyValues.clear();
        for (int i = 0; i < blocks; i++) {
            pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), Data(DataType::FLOAT32)));
        }
        intParams.clear();
        currentTokens.clear();
        while (resultTokenQueue.size() > 0) {
            resultTokenQueue.pop();
        }
        isEnding = false;
        preTokens = 0;
    }

    std::string basellm::Response(const std::string &input, RuntimeResult retCb,
                                  const fastllm::GenerationConfig &generationConfig) {
#ifdef USE_CUDA
        FastllmCudaClearBigBuffer();
#endif
        std::string prompt = input;
#ifdef PY_API
        size_t pos = input.find_last_of("time_stamp:");
        prompt = (generationConfig.enable_hash_id && pos != std::string::npos) ? input.substr(0, pos - 10) : input;
        size_t hash_id = std::hash<std::string>{}(input);
#endif
        Data inputIds, attentionMask, positionIds;

        Data inputTokenData = this->weight.tokenizer.Encode(prompt);
        std::vector<std::vector<float>> inputTokens;
        inputTokens.resize(1);
        for (int i = 0; i < inputTokenData.Count(0); i++) {
            inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]);
        }

        std::vector<std::pair<Data, Data>> pastKeyValues;
        for (int i = 0; i < block_cnt; i++) {
            pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), Data(DataType::FLOAT32)));
        }

        std::string retString = "";
        std::vector<float> results;
        LastTokensManager tokens(1, generationConfig.last_n);
        int promptLen = inputTokens[0].size(), index = 0;
        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}},
                      inputIds, attentionMask, positionIds);
        while (true) {
            auto st = std::chrono::system_clock::now();
            int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
            tokens.units[0].Push(ret);
            if (ret == eos_token_id) {
                break;
            }

            results.push_back(ret);
            std::string curString = weight.tokenizer.Decode(
                    Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str();
            retString += curString;
            if (retCb)
#ifdef PY_API
            {
                if (generationConfig.enable_hash_id) {
                    std::stringstream ss;
                    ss << retString << "hash_id:" << hash_id;
                    retCb(index, pybind11::bytes(ss.str()));
                } else {
                    retCb(index, pybind11::bytes(retString));
                }
            }
#else
                retCb(index, curString.c_str());
#endif
            index++;
            fflush(stdout);
            results.clear();

            inputTokens[0] = std::vector<float> {(float) ret};
            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}},
                          inputIds, attentionMask, positionIds);
            if (index == generationConfig.output_token_limit) {
                break;
            }
            // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now()));
        }
        if (retCb)
#ifdef PY_API
        {
            if (generationConfig.enable_hash_id) {
                std::stringstream ss;
                ss << retString << "hash_id:" << hash_id;
                retCb(-1, pybind11::bytes(ss.str()));
            } else {
                retCb(-1, pybind11::bytes(retString));
            }
        }
#else
            retCb(-1, retString.c_str());
#endif
        return retString;
    }
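    // A minimal usage sketch for the streaming Response() API above. Obtaining
    // the model pointer is assumed to happen elsewhere (model loading is not
    // part of this file); only Response(), RuntimeResult and GenerationConfig
    // are taken from this translation unit. Kept out of the build on purpose.
#if 0
    void ResponseUsageExample(basellm *model) {
        GenerationConfig config;
        config.output_token_limit = 256;
        // The callback receives the decode step index and the newly decoded
        // piece of text; index == -1 delivers the final, complete reply.
        std::string reply = model->Response("Hello!",
                [](int index, const char *content) {
                    if (index >= 0) {
                        printf("%s", content);
                        fflush(stdout);
                    }
                }, config);
    }
#endif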
    void basellm::ResponseBatch(const std::vector<std::string> &inputs, std::vector<std::string> &outputs,
                                RuntimeResultBatch retCb, const fastllm::GenerationConfig &generationConfig) {
#ifdef USE_CUDA
        FastllmCudaClearBigBuffer();
#endif
#ifdef PY_API
        std::vector<std::string> prompts;
        std::vector<size_t> hash_ids;
        for (auto _input : inputs) {
            size_t hash_id = std::hash<std::string>{}(_input);
            hash_ids.push_back(hash_id);

            size_t pos = _input.find_last_of("time_stamp:");
            std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos) ? _input.substr(0, pos - 10) : _input;
            prompts.push_back(prompt);
        }
#else
        std::vector<std::string> prompts = inputs;
#endif
        // 1. first
        Data inputIds, attentionMask, positionIds;

        int batch = prompts.size();
        outputs.clear();
        outputs.resize(batch, "");

        std::vector<std::vector<float>> inputTokens;
        inputTokens.resize(batch);
        for (int i = 0; i < batch; i++) {
            Data now = this->weight.tokenizer.Encode(prompts[i]);
            for (int j = 0; j < now.Count(0); j++) {
                inputTokens[i].push_back(((float *) now.cpuData)[j]);
            }
        }

        std::vector<std::pair<Data, Data>> pastKeyValues;
        for (int i = 0; i < block_cnt; i++) {
            pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), Data(DataType::FLOAT32)));
        }

        std::vector<std::map<std::string, int>> params;
        params.resize(batch);
        for (int i = 0; i < batch; i++) {
            params[i]["promptLen"] = (int) inputTokens[i].size();
        }
        params[0]["index"] = 0;
        int index = 0;

        LastTokensManager tokensManager(batch, generationConfig.last_n);
        std::vector<bool> isEnding = std::vector<bool>(batch, false);
        FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
        while (true) {
            auto st = std::chrono::system_clock::now();
            std::vector<int> ret = ForwardBatch(batch, inputIds, attentionMask, positionIds,
                                                pastKeyValues, generationConfig, tokensManager);
            for (int i = 0; i < batch; i++) {
                tokensManager.units[i].Push(ret[i]);
            }
            std::vector<float> fret;
            std::vector<float> results;
            int endingCount = 0;
            std::vector<std::string> curStrings;
            for (int i = 0; i < batch; i++) {
                fret.push_back(ret[i]);
                inputTokens[i] = std::vector<float> {(float) ret[i]};
                if (ret[i] == eos_token_id) {
                    isEnding[i] = true;
                }
                if (isEnding[i]) {
                    curStrings.push_back("");
                    endingCount++;
                    continue;
                }
                results.push_back(ret[i]);
                std::string curString = weight.tokenizer.Decode(
                        Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str();
                outputs[i] += curString;
                curStrings.push_back(curString);
                results.clear();
            }
            if (endingCount == batch) {
                break;
            }
            if (retCb)
#ifdef PY_API
            {
                if (generationConfig.enable_hash_id) {
                    std::vector<pybind11::bytes> rtnStrings;
                    for (size_t i = 0; i < batch; i++) {
                        std::stringstream ss;
                        ss << curStrings[i] << "hash_id:" << hash_ids[i];
                        rtnStrings.push_back(pybind11::bytes(ss.str()));
                    }
                    retCb(index, rtnStrings);
                } else {
                    std::vector<pybind11::bytes> rtnStrings;
                    for (size_t i = 0; i < batch; i++) {
                        rtnStrings.push_back(pybind11::bytes(curStrings[i]));
                    }
                    retCb(index, rtnStrings);
                }
            }
#else
                retCb(index, curStrings);
#endif
            index++;
            params[0]["index"] = index;
            FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
            if (index == generationConfig.output_token_limit) {
                break;
            }
        }
        if (retCb)
#ifdef PY_API
        {
            if (generationConfig.enable_hash_id) {
                std::vector<pybind11::bytes> rtnStrings;
                for (size_t i = 0; i < batch; i++) {
                    std::stringstream ss;
                    ss << outputs[i] << "hash_id:" << hash_ids[i];
                    rtnStrings.push_back(pybind11::bytes(ss.str()));
                }
                retCb(-1, rtnStrings);
            } else {
                std::vector<pybind11::bytes> rtnStrings;
                for (size_t i = 0; i < batch; i++) {
                    rtnStrings.push_back(pybind11::bytes(outputs[i]));
                }
                retCb(-1, rtnStrings);
            }
        }
#else
            retCb(-1, outputs);
#endif
    }
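    // A minimal usage sketch for ResponseBatch() above, assuming the caller
    // already holds a loaded model; all prompts decode in lock step until every
    // sequence hits eos or the token limit. Kept out of the build on purpose.
#if 0
    void ResponseBatchUsageExample(basellm *model) {
        std::vector<std::string> prompts = {"Hello!", "What is fastllm?"};
        std::vector<std::string> replies;
        // Passing an empty callback: results are only collected into `replies`.
        model->ResponseBatch(prompts, replies, nullptr, GenerationConfig());
        for (size_t i = 0; i < replies.size(); i++) {
            printf("[%d] %s\n", (int) i, replies[i].c_str());
        }
    }
#endif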
    std::vector<int> basellm::ForwardBatch(int batch, const fastllm::Data &inputIds,
                                           const fastllm::Data &attentionMask,
                                           const fastllm::Data &positionIds,
                                           std::vector<std::pair<Data, Data>> &pastKeyValues,
                                           const fastllm::GenerationConfig &generationConfig,
                                           const fastllm::LastTokensManager &lastTokens,
                                           std::vector<std::vector<float> *> *retLogits) {
        printf("Unsupported forward batch.\n");
        exit(0);
    }

    std::vector<int> basellm::ForwardBatch(int batch, const fastllm::Data &inputIds,
                                           const std::vector<Data *> &attentionMask,
                                           const std::vector<Data *> &positionIds,
                                           const std::vector<int> &seqLens,
                                           std::vector<std::pair<Data *, Data *>> &pastKeyValues,
                                           const std::vector<GenerationConfig> &generationConfigs,
                                           const fastllm::LastTokensManager &lastTokens,
                                           std::vector<std::vector<float> *> *logits) {
        // Fallback implementation: split the packed batch and run each request
        // through the single-sequence Forward(), copying its KV cache in and
        // back out around the call.
        std::vector<int> ret;
        int cur = 0;
        for (int i = 0; i < batch; i++) {
            std::vector<std::pair<Data, Data>> curKV;
            curKV.resize(this->block_cnt);
            for (int j = 0; j < this->block_cnt; j++) {
                Mul(*pastKeyValues[i * this->block_cnt + j].first, 1.0, curKV[j].first);
                Mul(*pastKeyValues[i * this->block_cnt + j].second, 1.0, curKV[j].second);
            }

            Data curInput;
            Split(inputIds, 1, cur, cur + seqLens[i], curInput);
            LastTokensManager curTokens;
            curTokens.units.push_back(lastTokens.units[i]);
            ret.push_back(this->Forward(curInput, *attentionMask[i], *positionIds[i], curKV,
                                        generationConfigs[i], curTokens));
            for (int j = 0; j < this->block_cnt; j++) {
                Mul(curKV[j].first, 1.0, *pastKeyValues[i * this->block_cnt + j].first);
                Mul(curKV[j].second, 1.0, *pastKeyValues[i * this->block_cnt + j].second);
            }
            cur += seqLens[i]; // advance to the next request's slice of inputIds
        }
        return ret;
    }
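    // A minimal sketch of the handle-based streaming API defined below:
    // LaunchResponseTokens() registers a request with the shared decode thread
    // and FetchResponseTokens() drains generated tokens until -1 marks the end.
    // The Encode() round trip mirrors what Response() does above; this block is
    // illustrative only and kept out of the build.
#if 0
    void TokenStreamUsageExample(basellm *model) {
        Data encoded = model->weight.tokenizer.Encode("Hello!");
        std::vector<int> inputTokens;
        for (int i = 0; i < encoded.Count(0); i++) {
            inputTokens.push_back((int) ((float *) encoded.cpuData)[i]);
        }
        int handle = model->LaunchResponseTokens(inputTokens, GenerationConfig());
        int token;
        while ((token = model->FetchResponseTokens(handle)) != -1) {
            // Decode incrementally with model->weight.tokenizer if text is needed.
            printf("%d ", token);
        }
    }
#endif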
    int basellm::LaunchResponseTokens(const std::vector<int> &inputTokens,
                                      const fastllm::GenerationConfig &generationConfig) {
        mainLoopLocker.lock();
        if (mainLoop == nullptr) {
            // Lazily start the shared decode thread; it keeps batching every
            // live request in responseContextDict until the process exits.
            mainLoop = new std::thread([](basellm *model) {
                while (true) {
                    std::vector<Data *> attentionMasks;
                    std::vector<Data *> positionIds;
                    std::vector<std::pair<Data *, Data *>> pastKeyValues;
                    std::vector<float> ids;
                    std::vector<int> seqLens;
                    std::vector<int> handles;
                    std::vector<GenerationConfig> generationConfigs;
                    LastTokensManager tokensManager;
                    std::vector<std::vector<float> *> logits;

                    model->dictLocker.lock();
                    int limit = model->tokensLimit > 0 ? model->tokensLimit : (int) 1e9;
                    int lenSum = 0;
                    for (auto &it : model->responseContextDict.dicts) {
                        if (it.second->pastKeyValues[0].first.expansionDims.size() > 0 && !it.second->isEnding) {
                            lenSum += it.second->pastKeyValues[0].first.expansionDims[1];
                        }
                    }
                    // Schedule prefill requests (isPrompt == 1) ahead of decode
                    // requests (isPrompt == 0); admit at most one prefill per pass.
                    for (int isPrompt = 1; isPrompt >= 0; isPrompt--) {
                        int cnt = 0;
                        if (isPrompt == 0 && seqLens.size() > 0) {
                            continue;
                        }
                        if (lenSum > limit && isPrompt) {
                            continue;
                        }
                        for (auto &it : model->responseContextDict.dicts) {
                            if (it.second->isEnding) {
                                continue;
                            }
                            if (isPrompt && it.second->preTokens != 0) {
                                continue;
                            }
                            if (!isPrompt && it.second->preTokens == 0) {
                                continue;
                            }
                            int outputLimit = it.second->generationConfig.output_token_limit;
                            outputLimit = (outputLimit < 0 ? 128 : outputLimit);
                            if (isPrompt && lenSum + it.second->currentTokens.size() + outputLimit > limit) {
                                continue;
                            }
                            generationConfigs.push_back(it.second->generationConfig);
                            if (it.second->generationConfig.output_logits) {
                                it.second->resultLogits.push(new std::vector<float>());
                                logits.push_back(it.second->resultLogits.back());
                            } else {
                                logits.push_back(nullptr);
                            }
                            tokensManager.units.push_back(it.second->tokens);
                            handles.push_back(it.first);

                            if (it.second->preTokens == 0) {
                                it.second->intParams["promptLen"] = it.second->currentTokens.size();
                                it.second->intParams["index"] = 0;
                            } else {
                                it.second->intParams["index"]++;
                            }

                            Data inputIds, attentionMask, curPositionIds;
                            std::vector<std::vector<float>> tokens;
                            tokens.resize(1);
                            for (int i : it.second->currentTokens) {
                                tokens[0].push_back(i);
                            }
                            model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask, curPositionIds);
                            seqLens.push_back(inputIds.Count(0));
                            for (int i = 0; i < inputIds.Count(0); i++) {
                                ids.push_back(((float *) inputIds.cpuData)[i]);
                            }
                            if (attentionMask.dims.size() == 0) {
                                attentionMasks.push_back(nullptr);
                            } else {
                                attentionMasks.push_back(new Data());
                                attentionMasks.back()->CopyFrom(attentionMask);
                            }
                            if (curPositionIds.dims.size() == 0) {
                                positionIds.push_back(nullptr);
                            } else {
                                positionIds.push_back(new Data());
                                positionIds.back()->CopyFrom(curPositionIds);
                            }
                            it.second->preTokens += seqLens.back();
                            for (int i = 0; i < model->block_cnt; i++) {
                                pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first,
                                                                       &it.second->pastKeyValues[i].second));
                            }
                            if (isPrompt) {
                                cnt += it.second->currentTokens.size();
                                break;
                            }
                        }
                    }

                    if (seqLens.size() > 0) {
                        std::vector<std::pair<Data, Data>> *pastKeyValue1 = nullptr;
                        if (seqLens.size() == 1) {
                            pastKeyValue1 = &model->responseContextDict.dicts[handles[0]]->pastKeyValues;
                        }
                        model->dictLocker.unlock();
#ifdef USE_CUDA
                        FastllmCudaClearBigBuffer();
#endif
                        Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids);
                        std::vector<int> ret;
                        auto st = std::chrono::system_clock::now();
                        //ClearProfiler();
                        if (seqLens.size() > 1) {
                            ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, positionIds,
                                                      seqLens, pastKeyValues, generationConfigs, tokensManager,
                                                      &logits);
                        } else {
                            ret = std::vector<int>{model->Forward(inputIds,
                                                                  attentionMasks[0] == nullptr ? Data() : *attentionMasks[0],
                                                                  *positionIds[0], *pastKeyValue1,
                                                                  generationConfigs[0], tokensManager, logits[0])};
                        }
                        //PrintProfiler();
                        /*
                        static int tot = 0;
                        printf("len = %d, spend = %f s.\n", (int) seqLens.size(),
                               GetSpan(st, std::chrono::system_clock::now()));
                        tot += (int) seqLens.size();
                        printf("tot = %d\n", tot);
                        */
                        model->dictLocker.lock();
                        for (int i = 0; i < handles.size(); i++) {
                            auto &it = *model->responseContextDict.dicts.find(handles[i]);
                            int curRet = ret[i];
                            if (curRet == model->eos_token_id) {
                                it.second->isEnding = true;
                            } else {
                                it.second->currentTokens = std::vector<int>{curRet};
                                it.second->resultTokenQueue.push(curRet);
                                it.second->tokens.Push(curRet);
                                it.second->curTokens++;
                                if (it.second->curTokens == it.second->generationConfig.output_token_limit) {
                                    it.second->isEnding = true;
                                }
                            }
                        }
                    }
                    for (int i = 0; i < attentionMasks.size(); i++) {
                        delete attentionMasks[i];
                    }
                    for (int i = 0; i < positionIds.size(); i++) {
                        delete positionIds[i];
                    }
                    model->dictLocker.unlock();
                    MySleep(0);
                }
            }, this);
        }
        mainLoopLocker.unlock();

        dictLocker.lock();
        int handleId = responseContextDict.CreateHandle();
        ResponseContext *context = responseContextDict.GetHandle(handleId);
        context->Init(this->block_cnt);
        context->currentTokens = inputTokens;
        context->generationConfig = generationConfig;
        context->tokens = LastTokensUnit(generationConfig.last_n);
        dictLocker.unlock();
        return handleId;
    }

    int basellm::FetchResponseTokens(int handleId) {
        dictLocker.lock();
        ResponseContext *context = responseContextDict.GetHandle(handleId);
        if (context == nullptr) {
            dictLocker.unlock();
            return -1;
        } else {
            while (true) {
                if (context->resultTokenQueue.size() > 0) {
                    int ret = context->resultTokenQueue.front();
                    context->resultTokenQueue.pop();
                    dictLocker.unlock();
                    return ret;
                } else {
                    if (context->isEnding) {
                        responseContextDict.RemoveHandle(handleId);
                        dictLocker.unlock();
                        return -1;
                    }
                }
                dictLocker.unlock();
                MySleep(0);
                dictLocker.lock();
            }
        }
    }

    int basellm::FetchResponseLogits(int handleId, std::vector<float> &logits) {
        dictLocker.lock();
        ResponseContext *context = responseContextDict.GetHandle(handleId);
        if (context == nullptr) {
            dictLocker.unlock();
            return -1;
        } else {
            while (true) {
                if (context->resultTokenQueue.size() > 0) {
                    int ret = context->resultTokenQueue.front();
                    context->resultTokenQueue.pop();
                    if (!context->resultLogits.empty()) {
                        logits = *context->resultLogits.front();
                        delete context->resultLogits.front();
                        context->resultLogits.pop();
                    }
                    dictLocker.unlock();
                    return ret;
                } else {
                    if (context->isEnding) {
                        responseContextDict.RemoveHandle(handleId);
                        dictLocker.unlock();
                        return -1;
                    }
                }
                dictLocker.unlock();
                MySleep(0);
                dictLocker.lock();
            }
        }
    }

    // Build the model-specific inference inputs (inputIds / attentionMask /
    // positionIds) from the given tokens; concrete models override this.
    void basellm::FillLLMInputs(std::vector<std::vector<float>> &inputTokens,
                                const std::map<std::string, int> &params,
                                Data &inputIds, Data &attentionMask, Data &positionIds) {
    }

    // Batched variant of FillLLMInputs; concrete models override this as well.
    void basellm::FillLLMInputsBatch(std::vector<std::vector<float>> &inputTokens,
                                     const std::vector<std::map<std::string, int>> &params,
                                     fastllm::Data &inputIds, fastllm::Data &attentionMask,
                                     fastllm::Data &positionIds) {
    }
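    // A minimal sketch of what a single-sequence FillLLMInputs() override
    // typically produces. ExampleModel and this exact layout are hypothetical;
    // the real subclasses in this repo each have their own variants. Kept out
    // of the build on purpose.
#if 0
    void ExampleModel::FillLLMInputs(std::vector<std::vector<float>> &inputTokens,
                                     const std::map<std::string, int> &params,
                                     Data &inputIds, Data &attentionMask, Data &positionIds) {
        int index = params.find("index")->second;
        int promptLen = params.find("promptLen")->second;
        if (index == 0) {
            // Prefill: feed the whole prompt; positions run 0 .. promptLen - 1.
            std::vector<float> positions;
            for (int i = 0; i < promptLen; i++) {
                positions.push_back((float) i);
            }
            inputIds.CopyFrom(Data(DataType::FLOAT32, {1, promptLen}, inputTokens[0]));
            positionIds.CopyFrom(Data(DataType::FLOAT32, {1, promptLen}, positions));
            attentionMask = Data(); // an empty Data is treated as "no mask" by the callers above
        } else {
            // Decode: one new token at absolute position promptLen + index - 1.
            inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
            positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1},
                                      {(float) (promptLen + index - 1)}));
            attentionMask = Data();
        }
    }
#endif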
    void basellm::SetAdapter(const std::string &name) {
        if (weight.peftDict.find(name) == weight.peftDict.end()) {
            ErrorInFastLLM("Can't find adapter name: " + name);
        }
        adapterName = name;
    }

    void basellm::DisableAdapter() {
        adapterName = "";
    }
}
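// A minimal sketch of the logits path: when output_logits is set in the
// GenerationConfig, the decode thread stores a full logits vector per step and
// FetchResponseLogits() returns each sampled token together with that vector.
// Kept out of the build on purpose.
#if 0
void LogitsUsageExample(fastllm::basellm *model, const std::vector<int> &tokens) {
    fastllm::GenerationConfig config;
    config.output_logits = true;
    int handle = model->LaunchResponseTokens(tokens, config);
    std::vector<float> logits;
    int token;
    while ((token = model->FetchResponseLogits(handle, logits)) != -1) {
        printf("token %d, logits size %d\n", token, (int) logits.size());
    }
}
#endif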