Commit 56215723 authored by zhouxiang

1. Sync with the latest upstream version; 2. Add a batch inference interface; 3. Fix a memory leak; 4. Fix choppy streaming output for llama-series models

parent 44be91d3
@@ -34,14 +34,16 @@ namespace fastllm {
         cos.resize(max_positions);
         std::vector <float> invFreq;
         for (int i = 0; i < rotary_dim; i += 2) {
-            invFreq.push_back(1.0 / pow(10000, (float)i / rotary_dim));
+            int base = this->bot_role.empty() ? 10000 : 10000 * rope;
+            invFreq.push_back(1.0 / pow(base, (float)i / rotary_dim));
         }
         for (int i = 0; i < max_positions; i++) {
             sin[i].resize(rotary_dim);
             cos[i].resize(rotary_dim);
             for (int j = 0; j < invFreq.size(); j++) {
-                sin[i][j] = ::sin((float)i / rope * invFreq[j]);
-                cos[i][j] = ::cos((float)i / rope * invFreq[j]);
+                float scale = this->bot_role.empty() ? rope : 1.0f;
+                sin[i][j] = ::sin((float)i / scale * invFreq[j]);
+                cos[i][j] = ::cos((float)i / scale * invFreq[j]);
             }
         }
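Net effect of this hunk: when bot_role is empty (ChatGLM itself), position indices are divided by the rope ratio; for other model families the RoPE base is scaled instead. A minimal standalone sketch of the resulting table computation, assuming `rope` carries the rope_ratio value and with `chatglmStyle` standing in for this->bot_role.empty() (names outside the diff are illustrative):

    #include <cmath>
    #include <vector>

    // Sketch of UpdateSinCos after this change.
    void BuildSinCos(int max_positions, int rotary_dim, float rope, bool chatglmStyle,
                     std::vector<std::vector<float>> &sinT, std::vector<std::vector<float>> &cosT) {
        float base = chatglmStyle ? 10000.0f : 10000.0f * rope;  // non-ChatGLM: scale the RoPE base
        std::vector<float> invFreq;
        for (int i = 0; i < rotary_dim; i += 2)
            invFreq.push_back(1.0f / powf(base, (float) i / rotary_dim));
        float scale = chatglmStyle ? rope : 1.0f;                // ChatGLM: stretch the positions
        sinT.assign(max_positions, std::vector<float>(rotary_dim, 0.0f));
        cosT.assign(max_positions, std::vector<float>(rotary_dim, 0.0f));
        for (int i = 0; i < max_positions; i++)
            for (int j = 0; j < (int) invFreq.size(); j++) {
                sinT[i][j] = sinf((float) i / scale * invFreq[j]);
                cosT[i][j] = cosf((float) i / scale * invFreq[j]);
            }
    }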
@@ -59,8 +61,9 @@ namespace fastllm {
     ChatGLMModel::ChatGLMModel() {
         this->model_type = "chatglm";
-        this->bos_token_id = 130004;
-        this->eos_token_id = 130005;
+        this->bos_token_id = 130004;  // bos token of later V1 releases; can be overridden via config.json
+        this->eos_token_id = 130005;  // eos token of later V1 releases; can be overridden via config.json
+        this->gmask_token_id = 150001;  // earliest V1 release (150528 tokens); some config.json files lack gmask_token_id, so use this default
         this->rope = -1.0;
         this->UpdateSinCos(1.0f);
@@ -68,6 +71,33 @@ namespace fastllm {
         weight.embeddingNames.insert("transformer.embedding.word_embeddings.weight");
     }

+    void ChatGLMModel::InitParams() {
+        basellm::InitParams();
+        if (GetVersion() == 1) {
+            if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
+                this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
+            }
+        } else if (GetVersion() == 2) {
+            this->gmask_token_id = 64790;
+            this->bos_token_id = 64792;
+        }
+        if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
+            UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
+        }
+        if (model_type == "chatglm3") {
+            int special_id = 64789;
+            this->mask_token_id = special_id++;
+            this->gmask_token_id = special_id++;
+            this->smask_token_id = special_id++;
+            this->bos_token_id = special_id++;
+            this->eop_token_id = special_id++;
+            this->system_token_id = special_id++;
+            this->user_token_id = special_id++;
+            this->assistant_token_id = special_id++;
+            this->observation_token_id = special_id++;
+        }
+    }
     int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
                               const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
                               const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
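For reference, the chatglm3 branch above assigns ids sequentially from a post-incremented special_id starting at 64789, so it resolves to:

    // Resolved chatglm3 special-token ids (derived from the increments above):
    const int mask_id = 64789, gmask_id = 64790, smask_id = 64791,
              bos_id  = 64792, eop_id   = 64793, system_id = 64794,
              user_id = 64795, assistant_id = 64796, observation_id = 64797;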
@@ -86,9 +116,6 @@ namespace fastllm {
                               const GenerationConfig &generationConfig,
                               const LastTokensManager &lastTokens,
                               std::vector <std::vector <float>*> *retLogits) {
-        if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
-            UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
-        }
         int maxLen = inputIds.dims[1];
         Data inputEmbeddings;
         Data attenInput;
@@ -218,18 +245,16 @@ namespace fastllm {
                 CatDirect(pastKey, k, 1);
                 CatDirect(pastValue, v, 1);
                 std::vector<int> outputSize = {q.dims[1], q.dims[2], q.dims[0], pastKey.dims[1]};
                 q.Reshape({q.dims[0], q.dims[1] * q.dims[2], q.dims[3]});
                 PermuteSelf(q, {1, 0, 2});
-                //Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
+                Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
+                /*
                 // 1.2 Attention
                 // 1.2.0 q * k^T
                 q.Reshape({pastKey.dims[0], -1, q.dims[2]});
                 MatMulTransB(q, pastKey, attnProbs, 1.0 / (scale_attn * (i + 1)));
                 attnProbs.Reshape(outputSize);
                 // 1.2.1 Mask
                 if (attentionMask.dims.size() != 0) {
                     AttentionMask(attnProbs, attentionMask, -10000);
@@ -243,6 +268,8 @@ namespace fastllm {
                 attnProbs.Reshape({pastValue.dims[0], -1, attnProbs.dims[2]});
                 MatMul(attnProbs, pastValue, contextLayer);
+                */
                 contextLayer.Reshape({batch, num_attention_heads, maxLen, -1});
                 PermuteSelf(contextLayer, {2, 0, 1, 3});
                 contextLayer.Reshape({contextLayer.dims[0], contextLayer.dims[1], embed_dim});
@@ -286,6 +313,17 @@ namespace fastllm {
         }
         Data logits, topk;
+        Data tempHiddenStates;
+        Data *lastHiddenStates;
+        if (maxLen > 1) {
+            Split(hiddenStates, 0, maxLen - 1, maxLen, tempHiddenStates);
+            lastHiddenStates = &tempHiddenStates;
+        } else {
+            lastHiddenStates = &hiddenStates;
+        }
+        {
+            auto &hiddenStates = *lastHiddenStates;
         if (version == 1) {
             LayerNorm(hiddenStates, weight["transformer.final_layernorm.weight"],
                       weight["transformer.final_layernorm.bias"], -1, hiddenStates);
@@ -298,24 +336,26 @@ namespace fastllm {
             int size = logits.dims.back();
             logits.ToDevice(DataDevice::CPU);
             for (int b = 0; b < batch; b++) {
-                int base = (maxLen - 1) * batch + b;
+                int base = b;
                 (*retLogits)[b]->resize(size);
-                memcpy((float*)(*retLogits)[b]->data(), ((float*)logits.cpuData) + base * size, size * logits.unitSize);
+                memcpy((float *) (*retLogits)[b]->data(), ((float *) logits.cpuData) + base * size,
+                       size * logits.unitSize);
             }
         }
         if (generationConfig.IsSimpleGreedy()) {
             TopK(logits, topk, 1);
             topk.ToDevice(DataDevice::CPU);
             for (int b = 0; b < batch; b++) {
-                int base = (maxLen - 1) * batch + b;
+                int base = b;
                 lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3));
             }
         } else if (!lastTokens.units.empty()) {
             for (int b = 0; b < batch; b++) {
-                int base = (maxLen - 1) * batch + b;
+                int base = b;
                 lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
             }
         }
+        }
         return lastRet;
     }
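The base = b changes above follow from the new Split: during prefill only the last position's hidden state is fed to the LM head, so logits are laid out [1, batch, vocab] instead of [maxLen, batch, vocab] and the flat per-sample offset shrinks accordingly. A toy check of the two offsets (sizes are purely illustrative):

    #include <cstdio>

    int main() {
        int maxLen = 7, batch = 3, b = 2;            // toy sizes
        int oldBase = (maxLen - 1) * batch + b;      // row index into [maxLen * batch, vocab] logits
        int newBase = b;                             // row index into [batch, vocab] logits after Split
        printf("old row = %d, new row = %d (both address sample %d's last position)\n",
               oldBase, newBase, b);
        return 0;
    }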
@@ -329,9 +369,6 @@ namespace fastllm {
                               const std::vector <GenerationConfig> &generationConfigs,
                               const LastTokensManager &lastTokens,
                               std::vector <std::vector <float>*> *retLogits) {
-        if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
-            UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
-        }
         int seqLen = inputIds.dims[1];
         sinData.ToDevice(DataDevice::CUDA);
         cosData.ToDevice(DataDevice::CUDA);
@@ -344,7 +381,6 @@ namespace fastllm {
             weightPre = "transformer.encoder.layers.";
             weightMiddle = ".self_attention";
         }
-
         Data inputEmbeddings;
         Data inputIdsPermute;
         Permute(inputIds, {1, 0}, inputIdsPermute);
@@ -352,7 +388,6 @@ namespace fastllm {
                                             ".word_embeddings.weight"], inputEmbeddings);
         Data &hiddenStates = inputEmbeddings;
         hiddenStates.ToDevice(DataDevice::CUDA);
-
         Data attenInput;
         Data qkv, q, k, v;
         Data attnOutput;
@@ -365,7 +400,6 @@ namespace fastllm {
         curKs.resize(batch);
         curVs.resize(batch);
         curQs.resize(batch);
-
         bool all1 = true;
         for (int i = 0; i < batch; i++) {
             all1 &= (seqLens[i] == 1);
@@ -392,7 +426,6 @@ namespace fastllm {
         std::vector <std::vector <int> > outputSizes;
         outputSizes.resize(batch);
         for (int i = 0; i < block_cnt; i++) {
-
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
             if (version == 1) {
@@ -451,7 +484,6 @@ namespace fastllm {
             Data contextLayer = Data(DataType::FLOAT32);
             int total = 0;
             if (all1 && batch > 1) {
-
                 for (int b = 0; b < batch; b++) {
                     pointersK[b] = (&curKs[b]);
@@ -482,6 +514,7 @@ namespace fastllm {
                     total += seqLens[b];
                 }
             }
+
             for (int b = 0; b < batch; b++) {
                 auto &q = curQs[b], &k = curKs[b], &v = curVs[b];
                 Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt +
@@ -717,8 +750,6 @@ namespace fastllm {
         attentionMask.ToDevice(DataDevice::CPU);
         positionIds.ToDevice(DataDevice::CPU);
-        int gmask_token_id = this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end() ?
-                             atoi(this->weight.dicts["gmask_token_id"].c_str()) : 130001;
         int index = params.find("index")->second;
         int promptLen = params.find("promptLen")->second;
@@ -728,9 +759,9 @@ namespace fastllm {
                 ids.push_back(gmask_token_id);
                 ids.push_back(bos_token_id);
             } else if (GetVersion() == 2) {
-                if (ids.size() < 2 || ids[0] != 64790 || ids[1] != 64792) {
-                    ids.insert(ids.begin(), 64792);
-                    ids.insert(ids.begin(), 64790);
+                if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
+                    ids.insert(ids.begin(), this->bos_token_id);
+                    ids.insert(ids.begin(), this->gmask_token_id);
                 }
             }
         }
@@ -780,8 +811,6 @@ namespace fastllm {
         int batch = inputTokens.size();
         int index = params[0].find("index")->second;
         if (index == 0) {
-            int gmask_token_id = this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end() ?
-                                 atoi(this->weight.dicts["gmask_token_id"].c_str()) : 130001;
             std::vector<int> seqLens;
             seqLens.resize(batch);
             int maxLen = 0;
@@ -820,8 +849,8 @@ namespace fastllm {
             } else {
                 auto &tokens = inputTokens[i];
                 int len = tokens.size(), base = maxLen - 2 - len;
-                ids[i * maxLen + base] = 64790;
-                ids[i * maxLen + base + 1] = 64792;
+                ids[i * maxLen + base] = gmask_token_id;
+                ids[i * maxLen + base + 1] = bos_token_id;
                 for (int j = 0; j < len; j++) {
                     ids[i * maxLen + base + 2 + j] = tokens[j];
                 }
@@ -894,28 +923,32 @@ namespace fastllm {
     }

     std::string ChatGLMModel::MakeInput(const std::string &history, int round, const std::string &input) {
+        if (this->bot_role != "") {
+            return (round == 0 ? pre_prompt : history) + user_role + input + bot_role;
+        } else {
+            if (GetVersion() == 2)
+                round++;
         if (round == 0 && GetVersion() == 1) {
             return input;
         } else {
 #if defined(_WIN32) or defined(_WIN64)
-            std::vector <uint8_t> vask = {233, 151, 174, 239, 188, 154, 0};
-            std::vector <uint8_t> vans = {231, 173, 148, 239, 188, 154, 0};
-            std::string sask = (char*)vask.data();
-            std::string sans = (char*)vans.data();
-            return (history + ("[Round " + std::to_string(round) + "]\n\n" + sask + input + "\n\n" + sans));
+            return history + ("[Round " + std::to_string(round) + u8"]\n\n问:" + input + u8"\n\n答:");
 #else
             return history + ("[Round " + std::to_string(round) + "]\n\n问:" + input + "\n\n答:");
 #endif
         }
         }
+    }

     std::string ChatGLMModel::MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {
+        if (this->bot_role != "") {
+            return (round == 0 ? pre_prompt : history) + user_role + input + bot_role + output + history_sep;
+        }
+        if (GetVersion() == 2)
+            round++;
 #if defined(_WIN32) or defined(_WIN64)
-        std::vector <uint8_t> vask = {233, 151, 174, 239, 188, 154, 0};
-        std::vector <uint8_t> vans = {231, 173, 148, 239, 188, 154, 0};
-        std::string sask = (char*)vask.data();
-        std::string sans = (char*)vans.data();
-        return (history + ("[Round " + std::to_string(round) + "]\n\n" + sask + input + "\n\n" + sans + output + "\n"));
+        return (history + ("[Round " + std::to_string(round) + u8"]\n\n问:" + input + u8"\n\n答:" + output + "\n"));
 #else
         return (history + ("[Round " + std::to_string(round) + "]\n\n问:" + input + "\n\n答:" + output + "\n\n"));
 #endif
...
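A note on the Windows branches replaced above: the removed byte vector {233, 151, 174, 239, 188, 154} is exactly the UTF-8 encoding of "问：" (and vans, of "答："), a workaround for MSVC decoding plain narrow literals in the local code page. The u8 prefix guarantees UTF-8 directly (under C++17; in C++20, u8 literals switch to char8_t). A toy check:

    #include <cstdio>

    int main() {
        const char *s = u8"问：";                // C++17: const char* holding UTF-8 bytes (fullwidth colon)
        for (const char *p = s; *p; ++p)
            printf("%d ", (unsigned char) *p);  // prints 233 151 174 239 188 154, matching the old vask
        printf("\n");
        return 0;
    }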
@@ -207,8 +207,8 @@ namespace fastllm {
                                      RuntimeResult retCb,
                                      const GenerationConfig &generationConfig) {
 #ifdef PY_API
-        size_t pos = input.find_last_of("time_stamp:");
-        std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos)? input.substr(0, pos-10):input;
+        size_t pos = input.rfind("time_stamp:");
+        std::string prompt = (generationConfig.enable_hash_id && pos != -1)? input.substr(0, pos):input;
         size_t hash_id = std::hash<std::string>{}(input);
         Data inputIds = this->weight.tokenizer.Encode(prompt);
 #else
...
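The one-line fix above is subtle: std::string::find_last_of returns the position of the last occurrence of any single character from its argument set, not of the whole substring, which is why the old code needed the fragile pos-10 back-off; std::string::rfind locates the start of the last full "time_stamp:" occurrence. A toy demonstration:

    #include <cstdio>
    #include <string>

    int main() {
        std::string input = "hello world time_stamp:12345";
        size_t anyChar = input.find_last_of("time_stamp:");  // last char from the SET {t,i,m,e,_,s,a,p,:}
        size_t token   = input.rfind("time_stamp:");         // start of the whole substring
        printf("find_last_of -> %zu, rfind -> %zu\n", anyChar, token);  // 22 vs 12
        return 0;
    }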
+uvicorn==0.23.2
+pydantic==2.5.1
+fastapi==0.103.1
+sse_starlette
-openai
+openai==0.28
@@ -14,5 +14,5 @@ if __name__ == "__main__":
     except:
         pass
     dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
-    exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-' + dtype + '.flm"
+    exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-" + dtype + ".flm"
     torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)