Commit 56215723 authored by zhouxiang's avatar zhouxiang

1. Sync with the latest upstream version; 2. Add a batch inference interface; 3. Fix a memory leak; 4. Fix choppy streaming output for the llama-family models.

parent 44be91d3
@@ -34,14 +34,16 @@ namespace fastllm {
cos.resize(max_positions);
std::vector <float> invFreq;
for (int i = 0; i < rotary_dim; i += 2) {
invFreq.push_back(1.0 / pow(10000, (float)i / rotary_dim));
int base = this->bot_role.empty() ? 10000 : 10000 * rope;
invFreq.push_back(1.0 / pow(base, (float)i / rotary_dim));
}
for (int i = 0; i < max_positions; i++) {
sin[i].resize(rotary_dim);
cos[i].resize(rotary_dim);
for (int j = 0; j < invFreq.size(); j++) {
sin[i][j] = ::sin((float)i / rope * invFreq[j]);
cos[i][j] = ::cos((float)i / rope * invFreq[j]);
float scale = this->bot_role.empty() ? rope : 1.0f;
sin[i][j] = ::sin((float)i / scale * invFreq[j]);
cos[i][j] = ::cos((float)i / scale * invFreq[j]);
}
}
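A minimal, self-contained sketch of the rope_ratio handling in the hunk above (the helper name and standalone framing are assumptions, not part of the commit; the two branches mirror the diff: an empty bot_role divides the position by the ratio, a non-empty one multiplies the 10000 base):
// Sketch (not part of the commit): rebuild the cos table from the hunk above
// as a free function. scaleBase == !bot_role.empty(); when it is false the
// ratio divides the position, when it is true it multiplies the base.
#include <cmath>
#include <vector>
static std::vector<std::vector<float>> BuildCosTable(int maxPositions, int rotaryDim,
                                                     float rope, bool scaleBase) {
    float base = scaleBase ? 10000.0f * rope : 10000.0f;
    std::vector<float> invFreq;
    for (int i = 0; i < rotaryDim; i += 2) {
        invFreq.push_back(1.0f / std::pow(base, (float)i / rotaryDim));
    }
    float posScale = scaleBase ? 1.0f : rope;
    std::vector<std::vector<float>> cosTable(maxPositions, std::vector<float>(rotaryDim, 0.0f));
    for (int i = 0; i < maxPositions; i++) {
        for (size_t j = 0; j < invFreq.size(); j++) {
            cosTable[i][j] = std::cos((float)i / posScale * invFreq[j]);
        }
    }
    return cosTable;
}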
@@ -59,8 +61,9 @@ namespace fastllm {
ChatGLMModel::ChatGLMModel() {
this->model_type = "chatglm";
this->bos_token_id = 130004;
this->eos_token_id = 130005;
this->bos_token_id = 130004; // bos token of later V1 releases; can be overridden via config.json
this->eos_token_id = 130005; // eos token of later V1 releases; can be overridden via config.json
this->gmask_token_id = 150001; // from the original V1 release (150528-token vocab); some config.json files omit gmask_token_id, so this is the default.
this->rope = -1.0;
this->UpdateSinCos(1.0f);
@@ -68,6 +71,33 @@ namespace fastllm {
weight.embeddingNames.insert("transformer.embedding.word_embeddings.weight");
}
void ChatGLMModel::InitParams() {
basellm::InitParams();
if (GetVersion() == 1) {
if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
}
} else if (GetVersion() == 2) {
this->gmask_token_id = 64790;
this->bos_token_id = 64792;
}
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
}
if (model_type == "chatglm3"){
int special_id = 64789;
this->mask_token_id = special_id++;
this->gmask_token_id = special_id++;
this->smask_token_id = special_id++;
this->bos_token_id = special_id++;
this->eop_token_id = special_id++;
this->system_token_id = special_id++;
this->user_token_id = special_id++;
this->assistant_token_id = special_id++;
this->observation_token_id = special_id++;
}
}
int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
@@ -86,9 +116,6 @@ namespace fastllm {
const GenerationConfig &generationConfig,
const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
}
int maxLen = inputIds.dims[1];
Data inputEmbeddings;
Data attenInput;
@@ -218,18 +245,16 @@ namespace fastllm {
CatDirect(pastKey, k, 1);
CatDirect(pastValue, v, 1);
std::vector<int> outputSize = {q.dims[1], q.dims[2], q.dims[0], pastKey.dims[1]};
q.Reshape({q.dims[0], q.dims[1] * q.dims[2], q.dims[3]});
PermuteSelf(q, {1, 0, 2});
//Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
/*
// 1.2 Attention
// 1.2.0 q * k^T
q.Reshape({pastKey.dims[0], -1, q.dims[2]});
MatMulTransB(q, pastKey, attnProbs, 1.0 / (scale_attn * (i + 1)));
attnProbs.Reshape(outputSize);
// 1.2.1 Mask
if (attentionMask.dims.size() != 0) {
AttentionMask(attnProbs, attentionMask, -10000);
@@ -243,6 +268,8 @@ namespace fastllm {
attnProbs.Reshape({pastValue.dims[0], -1, attnProbs.dims[2]});
MatMul(attnProbs, pastValue, contextLayer);
*/
contextLayer.Reshape({batch, num_attention_heads, maxLen, -1});
PermuteSelf(contextLayer, {2, 0, 1, 3});
contextLayer.Reshape({contextLayer.dims[0], contextLayer.dims[1], embed_dim});
@@ -286,6 +313,17 @@ namespace fastllm {
}
Data logits, topk;
Data tempHiddenStates;
Data *lastHiddenStates;
if (maxLen > 1) {
Split(hiddenStates, 0, maxLen - 1, maxLen, tempHiddenStates);
lastHiddenStates = &tempHiddenStates;
} else {
lastHiddenStates = &hiddenStates;
}
{
auto &hiddenStates = *lastHiddenStates;
if (version == 1) {
LayerNorm(hiddenStates, weight["transformer.final_layernorm.weight"],
weight["transformer.final_layernorm.bias"], -1, hiddenStates);
@@ -298,24 +336,26 @@ namespace fastllm {
int size = logits.dims.back();
logits.ToDevice(DataDevice::CPU);
for (int b = 0; b < batch; b++) {
int base = (maxLen - 1) * batch + b;
int base = b;
(*retLogits)[b]->resize(size);
memcpy((float*)(*retLogits)[b]->data(), ((float*)logits.cpuData) + base * size, size * logits.unitSize);
memcpy((float *) (*retLogits)[b]->data(), ((float *) logits.cpuData) + base * size,
size * logits.unitSize);
}
}
if (generationConfig.IsSimpleGreedy()) {
TopK(logits, topk, 1);
topk.ToDevice(DataDevice::CPU);
for (int b = 0; b < batch; b++) {
int base = (maxLen - 1) * batch + b;
int base = b;
lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3));
}
} else if (!lastTokens.units.empty()) {
for (int b = 0; b < batch; b++) {
int base = (maxLen - 1) * batch + b;
int base = b;
lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
}
}
}
return lastRet;
}
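A note on the base index change above, plus a tiny worked example (my reading of the diff, assuming the usual [seqLen, batch, hidden] layout this model uses): since only the last time step of hiddenStates is now split off before the LM head, logits hold one row per batch element, so the per-sample offset collapses from (maxLen - 1) * batch + b to b.
// Worked example (illustrative values only): row indices into the flattened
// logits before and after the change, for maxLen = 5 and batch = 2.
#include <cstdio>
int main() {
    int maxLen = 5, batch = 2;
    for (int b = 0; b < batch; b++) {
        int oldBase = (maxLen - 1) * batch + b; // row in the full [maxLen * batch] logits
        int newBase = b;                        // row in the last-step-only [batch] logits
        std::printf("b=%d old=%d new=%d\n", b, oldBase, newBase); // b=0: 8 -> 0, b=1: 9 -> 1
    }
    return 0;
}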
@@ -329,9 +369,6 @@ namespace fastllm {
const std::vector <GenerationConfig> &generationConfigs,
const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
}
int seqLen = inputIds.dims[1];
sinData.ToDevice(DataDevice::CUDA);
cosData.ToDevice(DataDevice::CUDA);
@@ -344,7 +381,6 @@ namespace fastllm {
weightPre = "transformer.encoder.layers.";
weightMiddle = ".self_attention";
}
Data inputEmbeddings;
Data inputIdsPermute;
Permute(inputIds, {1, 0}, inputIdsPermute);
@@ -352,7 +388,6 @@ namespace fastllm {
".word_embeddings.weight"], inputEmbeddings);
Data &hiddenStates = inputEmbeddings;
hiddenStates.ToDevice(DataDevice::CUDA);
Data attenInput;
Data qkv, q, k, v;
Data attnOutput;
@@ -365,7 +400,6 @@ namespace fastllm {
curKs.resize(batch);
curVs.resize(batch);
curQs.resize(batch);
bool all1 = true;
for (int i = 0; i < batch; i++) {
all1 &= (seqLens[i] == 1);
@@ -392,7 +426,6 @@ namespace fastllm {
std::vector <std::vector <int> > outputSizes;
outputSizes.resize(batch);
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
if (version == 1) {
@@ -451,7 +484,6 @@ namespace fastllm {
Data contextLayer = Data(DataType::FLOAT32);
int total = 0;
if (all1 && batch > 1) {
for (int b = 0; b < batch; b++) {
pointersK[b] = (&curKs[b]);
@@ -482,6 +514,7 @@ namespace fastllm {
total += seqLens[b];
}
}
for (int b = 0; b < batch; b++) {
auto &q = curQs[b], &k = curKs[b], &v = curVs[b];
Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt +
@@ -717,8 +750,6 @@ namespace fastllm {
attentionMask.ToDevice(DataDevice::CPU);
positionIds.ToDevice(DataDevice::CPU);
int gmask_token_id = this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end() ?
atoi(this->weight.dicts["gmask_token_id"].c_str()) : 130001;
int index = params.find("index")->second;
int promptLen = params.find("promptLen")->second;
@@ -728,9 +759,9 @@ namespace fastllm {
ids.push_back(gmask_token_id);
ids.push_back(bos_token_id);
} else if (GetVersion() == 2) {
if (ids.size() < 2 || ids[0] != 64790 || ids[1] != 64792) {
ids.insert(ids.begin(), 64792);
ids.insert(ids.begin(), 64790);
if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
ids.insert(ids.begin(), this->bos_token_id);
ids.insert(ids.begin(), this->gmask_token_id);
}
}
}
@@ -780,8 +811,6 @@ namespace fastllm {
int batch = inputTokens.size();
int index = params[0].find("index")->second;
if (index == 0) {
int gmask_token_id = this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end() ?
atoi(this->weight.dicts["gmask_token_id"].c_str()) : 130001;
std::vector<int> seqLens;
seqLens.resize(batch);
int maxLen = 0;
@@ -820,8 +849,8 @@ namespace fastllm {
} else {
auto &tokens = inputTokens[i];
int len = tokens.size(), base = maxLen - 2 - len;
ids[i * maxLen + base] = 64790;
ids[i * maxLen + base + 1] = 64792;
ids[i * maxLen + base] = gmask_token_id;
ids[i * maxLen + base + 1] = bos_token_id;
for (int j = 0; j < len; j++) {
ids[i * maxLen + base + 2 + j] = tokens[j];
}
@@ -894,28 +923,32 @@ namespace fastllm {
}
std::string ChatGLMModel::MakeInput(const std::string &history, int round, const std::string &input) {
if (this->bot_role != "") {
return (round == 0 ? pre_prompt : history) + user_role + input + bot_role;
} else {
if (GetVersion() == 2)
round++;
if (round == 0 && GetVersion() == 1) {
return input;
} else {
#if defined(_WIN32) or defined(_WIN64)
std::vector <uint8_t> vask = {233, 151, 174, 239, 188, 154, 0};
std::vector <uint8_t> vans = {231, 173, 148, 239, 188, 154, 0};
std::string sask = (char*)vask.data();
std::string sans = (char*)vans.data();
return (history + ("[Round " + std::to_string(round) + "]\n\n" + sask + input + "\n\n" + sans));
return history + ("[Round " + std::to_string(round) + u8"]\n\n问:" + input + u8"\n\n答:");
#else
return history + ("[Round " + std::to_string(round) + "]\n\n问:" + input + "\n\n答:");
#endif
}
}
}
std::string ChatGLMModel::MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {
if (this->bot_role != "") {
return (round == 0 ? pre_prompt : history) + user_role + input + bot_role + output + history_sep;
}
if (GetVersion() == 2)
round++;
#if defined(_WIN32) or defined(_WIN64)
std::vector <uint8_t> vask = {233, 151, 174, 239, 188, 154, 0};
std::vector <uint8_t> vans = {231, 173, 148, 239, 188, 154, 0};
std::string sask = (char*)vask.data();
std::string sans = (char*)vans.data();
return (history + ("[Round " + std::to_string(round) + "]\n\n" + sask + input + "\n\n" + sans + output + "\n"));
return (history + ("[Round " + std::to_string(round) + u8"]\n\n问:" + input + u8"\n\n答:" + output + "\n"));
#else
return (history + ("[Round " + std::to_string(round) + "]\n\n问:" + input + "\n\n答:" + output + "\n\n"));
#endif
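For readers unfamiliar with the ChatGLM V1/V2 chat template assembled above, a small illustration of the history string the non-Windows branch produces after one exchange (the user/assistant text and round value are invented for illustration):
// Illustration only: for GetVersion() == 2 the round is incremented before
// formatting, so the first exchange renders as "[Round 1]"; 问:/答: are the
// Chinese "Q:"/"A:" markers used by the template.
#include <iostream>
#include <string>
int main() {
    int round = 0 + 1; // V2 branch: round++ before formatting
    std::string input = "你好", output = "你好,有什么可以帮你?";
    std::string history = "[Round " + std::to_string(round) + "]\n\n问:" + input +
                          "\n\n答:" + output + "\n\n";
    std::cout << history;
    return 0;
}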
This diff is collapsed.
@@ -207,8 +207,8 @@ namespace fastllm {
RuntimeResult retCb,
const GenerationConfig &generationConfig) {
#ifdef PY_API
size_t pos = input.find_last_of("time_stamp:");
std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos)? input.substr(0, pos-10):input;
size_t pos = input.rfind("time_stamp:");
std::string prompt = (generationConfig.enable_hash_id && pos != -1)? input.substr(0, pos):input;
size_t hash_id = std::hash<std::string>{}(input);
Data inputIds = this->weight.tokenizer.Encode(prompt);
#else
This diff is collapsed.
uvicorn==0.23.2
pydantic==2.5.1
fastapi==0.103.1
sse_starlette
openai==0.28
This diff is collapsed.
@@ -14,5 +14,5 @@ if __name__ == "__main__":
except:
pass
dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-' + dtype + '.flm"
exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-" + dtype + ".flm"
torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
This diff is collapsed.