//
// Created by siemon on 8/9/23.
//

#include "utils.h"

#include "qwen.h"

#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <map>
#include <chrono>

#ifdef USE_CUDA
#include "fastllm-cuda.cuh"
#endif

namespace fastllm {
    extern double GetSpan(std::chrono::system_clock::time_point time1, std::chrono::system_clock::time_point time2);

    QWenModel::QWenModel() {
        this->model_type = "qwen";
        this->pre_prompt = "You are a helpful assistant.";
        this->user_role = "user";
        this->bot_role = "assistant";

        embed_dim = 4096;
        num_attention_heads = 32;
        head_dim = embed_dim / num_attention_heads;
        block_cnt = 32;
        rotary_dim = 128;
        seq_length = 2048;
        use_log_attn = true;
        ntk_alpha = 1.f;

        UpdateRotaryPosEmb(ntk_alpha);

        if (use_log_attn) {
            // Log-n attention scaling: positions inside the trained context keep
            // scale 1; positions beyond it are scaled by log_{seq_length}(i).
            logn_list = Data(DataType::FLOAT32);
            logn_list.Resize({1, max_positions, 1, 1});
            logn_list.Allocate();
            float *logn = (float *) logn_list.cpuData;
            for (int i = 0; i < seq_length; i++) {
                logn[i] = 1;
            }
            for (int i = seq_length; i < max_positions; i++) {
                logn[i] = std::log(i) / std::log(seq_length);
            }
        }

        weight.embeddingNames.insert("transformer.wte.weight");
    }

    int QWenModel::Forward(const Data &inputIds,
                           const Data &attentionMask,
                           const Data &positionIds,
                           std::vector<std::pair<Data, Data>> &pastKeyValues,
                           const GenerationConfig &generationConfig,
                           const LastTokensManager &lastTokens,
                           std::vector<float> *logits) {
        std::vector<std::vector<float>*> batchLogits;
        batchLogits.push_back(logits);
        return ForwardBatch(1, inputIds, attentionMask, positionIds, pastKeyValues,
                            generationConfig, lastTokens, &batchLogits)[0];
    }

    std::vector<int> QWenModel::ForwardBatch(int batch,
                                             const Data &inputIds,
                                             const Data &attentionMask,
                                             const Data &positionIds,
                                             std::vector<std::pair<Data, Data>> &pastKeyValues,
                                             const GenerationConfig &generationConfig,
                                             const LastTokensManager &lastTokens,
                                             std::vector<std::vector<float>*> *retLogits) {
        int maxLen = inputIds.dims[1];
        Data hiddenStates;
        Data attnInput, attnOutput;
        Data query, key, value;
        Data attnWeights, attnLastOutput;
        Data a1, a2, mlpOutput;

        // printf("input id: ");
        // for (int i = 0; i < inputIds.Count(0); i++) {
        //     printf("%d ", (int) ((float *) inputIds.cpuData)[i]);
        // }
        // printf("\n");

        Embedding(inputIds, this->weight["transformer.wte.weight"], hiddenStates);
        for (int i = 0; i < this->block_cnt; i++) {
            ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
            int seqlen = hiddenStates.dims[1];
            std::string ln_1_name = "transformer.h." + std::to_string(i) + ".ln_1.weight";
            std::string attn_weight_name = "transformer.h." + std::to_string(i) + ".attn.c_attn.weight";
            std::string attn_bias_name = "transformer.h." + std::to_string(i) + ".attn.c_attn.bias";
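            // Pre-norm attention: RMSNorm feeds a single fused c_attn projection
            // whose [batch, seqlen, 3 * embed_dim] output is split into Q, K, V.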
            RMSNorm(hiddenStates, weight[ln_1_name], 1e-6, attnInput);
            Linear(attnInput, weight[attn_weight_name], weight[attn_bias_name], attnOutput); // attnOutput: [batch, seqlen, embed_dim * 3]
            Split(attnOutput, 2, 0, embed_dim, query);
            Split(attnOutput, 2, embed_dim, 2 * embed_dim, key);
            Split(attnOutput, 2, embed_dim * 2, embed_dim * 3, value);

            query.Reshape({query.dims[0], query.dims[1], num_attention_heads, head_dim});
            key.Reshape({key.dims[0], key.dims[1], num_attention_heads, head_dim});
            value.Reshape({value.dims[0], value.dims[1], num_attention_heads, head_dim});

            Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second;
            if (pastKey.dims.empty()) {
                // Compute new_ntk_alpha: dynamic NTK-aware RoPE scaling, re-deriving
                // the rotary tables when the prompt exceeds the trained seq_length.
                float context_value = std::log2((float) seqlen / seq_length) + 1;
                float new_ntk_alpha = std::max(std::pow(2, std::ceil(context_value) - 1), 1.);
                if (new_ntk_alpha != ntk_alpha) {
                    UpdateRotaryPosEmb(new_ntk_alpha);
                }
            }

            LlamaRotatePosition2D(query, positionIds, sinData, cosData, rotary_dim);
            LlamaRotatePosition2D(key, positionIds, sinData, cosData, rotary_dim);

            if (use_log_attn) {
                ApplyLognAttn(query, logn_list, positionIds);
            }

            PermuteSelf(query, {0, 2, 1, 3});
            PermuteSelf(key, {0, 2, 1, 3});
            PermuteSelf(value, {0, 2, 1, 3});

            std::vector<int> qkvSize = {batch * num_attention_heads, seqlen, -1};
            query.Reshape(qkvSize);
            key.Reshape(qkvSize);
            value.Reshape(qkvSize);

            // Grow the KV cache in unitLen-aligned chunks before appending this step's keys/values.
            int unitLen = 64;
#ifdef USE_CUDA
            unitLen = 128;
#endif
            while ((pastKey.dims.size() == 0 && (pastKey.expansionDims.size() == 0 || key.dims[1] > pastKey.expansionDims[1]))
                   || (pastKey.dims.size() > 0 && pastKey.dims[1] + key.dims[1] > pastKey.expansionDims[1])) {
                std::vector<int> newDims;
                if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) {
                    newDims = std::vector<int> {key.dims[0], ((key.dims[1] - 1) / unitLen + 1) * unitLen, key.dims[2]};
                } else {
                    newDims = pastKey.dims;
                    newDims[1] += ((key.dims[1] - 1) / unitLen + 1) * unitLen;
                }
                pastKey.Expansion(newDims);
            }
            while ((pastValue.dims.size() == 0 && (pastValue.expansionDims.size() == 0 || value.dims[1] > pastValue.expansionDims[1]))
                   || (pastValue.dims.size() > 0 && pastValue.dims[1] + value.dims[1] > pastValue.expansionDims[1])) {
                std::vector<int> newDims;
                if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) {
                    newDims = std::vector<int> {value.dims[0], ((value.dims[1] - 1) / unitLen + 1) * unitLen, value.dims[2]};
                } else {
                    newDims = pastValue.dims;
                    newDims[1] += ((value.dims[1] - 1) / unitLen + 1) * unitLen;
                }
                pastValue.Expansion(newDims);
            }
            CatDirect(pastKey, key, 1);
            CatDirect(pastValue, value, 1);

            // Attention: scores = Q * K^T / sqrt(head_dim), masked, softmaxed, applied to V.
            MatMulTransB(query, pastKey, attnWeights, 1.0 / sqrt(head_dim));
            attnWeights.Reshape({1, attnWeights.dims[0], attnWeights.dims[1], attnWeights.dims[2]});
            if (!attentionMask.dims.empty()) {
                AttentionMask(attnWeights, attentionMask, -10000);
            }
            Softmax(attnWeights, attnWeights, -1);
            MatMul(attnWeights, pastValue, attnOutput);

            attnOutput.Reshape({attnOutput.dims[1], attnOutput.dims[2], attnOutput.dims[3]});
            PermuteSelf(attnOutput, {1, 0, 2});
            attnOutput.Reshape({seqlen, batch, -1});
            PermuteSelf(attnOutput, {1, 0, 2});

            std::string proj_weight_name = "transformer.h." + std::to_string(i) + ".attn.c_proj.weight";
            Linear(attnOutput, weight[proj_weight_name], Data(), attnLastOutput);
            AddTo(hiddenStates, attnLastOutput);

            std::string ln_2_name = "transformer.h." + std::to_string(i) + ".ln_2.weight";
            RMSNorm(hiddenStates, weight[ln_2_name], 1e-6, attnInput);
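            // SwiGLU-style MLP: silu(w2(x)) gates w1(x); the product is projected
            // back to embed_dim by mlp.c_proj and added to the residual stream.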
            std::string mlp_w1_weight_name = "transformer.h." + std::to_string(i) + ".mlp.w1.weight";
            std::string mlp_w2_weight_name = "transformer.h." + std::to_string(i) + ".mlp.w2.weight";
            std::string mlp_proj_weight_name = "transformer.h." + std::to_string(i) + ".mlp.c_proj.weight";
            Linear(attnInput, weight[mlp_w1_weight_name], Data(), a1);
            Linear(attnInput, weight[mlp_w2_weight_name], Data(), a2);
            Silu(a2, a2);
            MulTo(a1, a2);
            Linear(a1, weight[mlp_proj_weight_name], Data(), mlpOutput);
            AddTo(hiddenStates, mlpOutput);
        }

        RMSNorm(hiddenStates, weight["transformer.ln_f.weight"], 1e-6, hiddenStates);
        Data logits;
        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);

        std::vector<int> lastRet;
        int total = 0;
        Data curLogitTemp, curLogit;
        for (int b = 0; b < batch; b++) {
            // Take the logits of the last position of request b.
            Split(logits, 0, b, b + 1, curLogitTemp);
            Split(curLogitTemp, 1, maxLen - 1, maxLen, curLogit);
            if (generationConfig.output_logits && retLogits != nullptr && (*retLogits)[b] != nullptr) {
                curLogit.ToDevice(DataDevice::CPU);
                (*retLogits)[b]->resize(curLogit.Count(0));
                memcpy((float *) (*retLogits)[b]->data(), (float *) curLogit.cpuData, curLogit.GetBytes());
            }
            if (generationConfig.IsSimpleGreedy()) {
                Data topk;
                TopK(curLogit, topk, 1);
                topk.ToDevice(DataDevice::CPU);
                lastRet.push_back((int) (((float *) topk.cpuData)[0] + 1e-3));
            } else {
                lastRet.push_back(LLMSampling(curLogit, 0, generationConfig, lastTokens.units[b]));
            }
            total += maxLen;
        }
        return lastRet;
    }
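    // Batched variant: requests of different lengths are packed along the
    // sequence axis of inputIds, so masks, position ids and KV caches are
    // passed per request and attention is computed one request at a time.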
    std::vector<int> QWenModel::ForwardBatch(int batch,
                                             const Data &inputIds,
                                             const std::vector<Data*> &attentionMask,
                                             const std::vector<Data*> &positionIds,
                                             const std::vector<int> &seqLens,
                                             std::vector<std::pair<Data*, Data*>> &pastKeyValues,
                                             const std::vector<GenerationConfig> &generationConfigs,
                                             const LastTokensManager &lastTokens,
                                             std::vector<std::vector<float>*> *retLogits) {
        int maxLen = inputIds.dims[1];
        Data hiddenStates;
        Data attnInput, attnOutput;
        Data query, key, value;
        Data attnWeights, attnLastOutput;
        Data a1, a2, mlpOutput;

        Embedding(inputIds, this->weight["transformer.wte.weight"], hiddenStates);
        for (int i = 0; i < this->block_cnt; i++) {
            ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
            std::string ln_1_name = "transformer.h." + std::to_string(i) + ".ln_1.weight";
            std::string attn_weight_name = "transformer.h." + std::to_string(i) + ".attn.c_attn.weight";
            std::string attn_bias_name = "transformer.h." + std::to_string(i) + ".attn.c_attn.bias";
            RMSNorm(hiddenStates, weight[ln_1_name], 1e-6, attnInput);
            Linear(attnInput, weight[attn_weight_name], weight[attn_bias_name], attnOutput); // attnOutput: [batch, seqlen, embed_dim * 3]
            Split(attnOutput, 2, 0, embed_dim, query);
            Split(attnOutput, 2, embed_dim, 2 * embed_dim, key);
            Split(attnOutput, 2, embed_dim * 2, embed_dim * 3, value);

            // Unpack the per-request Q/K/V slices from the packed sequence axis.
            std::vector<Data> curKs, curVs, curQs;
            curKs.resize(batch);
            curVs.resize(batch);
            curQs.resize(batch);
            int total = 0;
            for (int b = 0; b < batch; b++) {
                Split(query, 1, total, total + seqLens[b], curQs[b]);
                Split(key, 1, total, total + seqLens[b], curKs[b]);
                Split(value, 1, total, total + seqLens[b], curVs[b]);
                total += seqLens[b];
            }

            Data attnOutputAll = Data(DataType::FLOAT32);
            for (int b = 0; b < batch; b++) {
                // in this loop, batch = 1
                auto &query = curQs[b];
                auto &key = curKs[b];
                auto &value = curVs[b];
                query.Reshape({1, seqLens[b], num_attention_heads, head_dim});
                key.Reshape({1, seqLens[b], num_attention_heads, head_dim});
                value.Reshape({1, seqLens[b], num_attention_heads, head_dim});

                Data &pastKey = *pastKeyValues[b * block_cnt + i].first,
                     &pastValue = *pastKeyValues[b * block_cnt + i].second;
                if (pastKey.dims.empty()) {
                    // Compute new_ntk_alpha (dynamic NTK-aware RoPE scaling, as above).
                    float context_value = std::log2((float) seqLens[b] / seq_length) + 1;
                    float new_ntk_alpha = std::max(std::pow(2, std::ceil(context_value) - 1), 1.);
                    if (new_ntk_alpha != ntk_alpha) {
                        UpdateRotaryPosEmb(new_ntk_alpha);
                    }
                }

                LlamaRotatePosition2D(query, *positionIds[b], sinData, cosData, rotary_dim);
                LlamaRotatePosition2D(key, *positionIds[b], sinData, cosData, rotary_dim);

                if (use_log_attn) {
                    ApplyLognAttn(query, logn_list, *positionIds[b]);
                }

                PermuteSelf(query, {0, 2, 1, 3});
                PermuteSelf(key, {0, 2, 1, 3});
                PermuteSelf(value, {0, 2, 1, 3});

                std::vector<int> qkvSize = {num_attention_heads, seqLens[b], -1};
                query.Reshape(qkvSize);
                key.Reshape(qkvSize);
                value.Reshape(qkvSize);

                int unitLen = 64;
#ifdef USE_CUDA
                unitLen = 128;
#endif
                while ((pastKey.dims.size() == 0 && (pastKey.expansionDims.size() == 0 || key.dims[1] > pastKey.expansionDims[1]))
                       || (pastKey.dims.size() > 0 && pastKey.dims[1] + key.dims[1] > pastKey.expansionDims[1])) {
                    std::vector<int> newDims;
                    if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) {
                        newDims = std::vector<int> {key.dims[0], ((key.dims[1] - 1) / unitLen + 1) * unitLen, key.dims[2]};
                    } else {
                        newDims = pastKey.dims;
                        newDims[1] += ((key.dims[1] - 1) / unitLen + 1) * unitLen;
                    }
                    pastKey.Expansion(newDims);
                }
                while ((pastValue.dims.size() == 0 && (pastValue.expansionDims.size() == 0 || value.dims[1] > pastValue.expansionDims[1]))
                       || (pastValue.dims.size() > 0 && pastValue.dims[1] + value.dims[1] > pastValue.expansionDims[1])) {
                    std::vector<int> newDims;
                    if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) {
                        newDims = std::vector<int> {value.dims[0], ((value.dims[1] - 1) / unitLen + 1) * unitLen, value.dims[2]};
                    } else {
                        newDims = pastValue.dims;
                        newDims[1] += ((value.dims[1] - 1) / unitLen + 1) * unitLen;
                    }
                    pastValue.Expansion(newDims);
                }
                CatDirect(pastKey, key, 1);
                CatDirect(pastValue, value, 1);

                MatMulTransB(query, pastKey, attnWeights, 1.0 / sqrt(head_dim));
                attnWeights.Reshape({1, attnWeights.dims[0], attnWeights.dims[1], attnWeights.dims[2]});
                if (attentionMask[b]) {
                    AttentionMask(attnWeights, *attentionMask[b], -10000);
                }
                Softmax(attnWeights, attnWeights, -1);
                MatMul(attnWeights, pastValue, attnOutput);

                attnOutput.Reshape({attnOutput.dims[1], attnOutput.dims[2], attnOutput.dims[3]});
                PermuteSelf(attnOutput, {1, 0, 2});
                attnOutput.Reshape({seqLens[b], 1, -1});
                PermuteSelf(attnOutput, {1, 0, 2});

                if (attnOutputAll.dims.size() == 0) {
                    std::vector<int> dims = attnOutput.dims;
                    dims[1] = total;
                    attnOutputAll.Expansion(dims);
                }
                CatDirect(attnOutputAll, attnOutput, 1);
            }
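            // The per-request attention outputs are now packed back into
            // attnOutputAll along dim 1, so the output projection and the MLP
            // below run once over the whole packed batch.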
            std::string proj_weight_name = "transformer.h." + std::to_string(i) + ".attn.c_proj.weight";
            Linear(attnOutputAll, weight[proj_weight_name], Data(), attnLastOutput);
            AddTo(hiddenStates, attnLastOutput);

            std::string ln_2_name = "transformer.h." + std::to_string(i) + ".ln_2.weight";
            RMSNorm(hiddenStates, weight[ln_2_name], 1e-6, attnInput);
            std::string mlp_w1_weight_name = "transformer.h." + std::to_string(i) + ".mlp.w1.weight";
            std::string mlp_w2_weight_name = "transformer.h." + std::to_string(i) + ".mlp.w2.weight";
            std::string mlp_proj_weight_name = "transformer.h." + std::to_string(i) + ".mlp.c_proj.weight";
            Linear(attnInput, weight[mlp_w1_weight_name], Data(), a1);
            Linear(attnInput, weight[mlp_w2_weight_name], Data(), a2);
            Silu(a2, a2);
            MulTo(a1, a2);
            Linear(a1, weight[mlp_proj_weight_name], Data(), mlpOutput);
            AddTo(hiddenStates, mlpOutput);
        }

        RMSNorm(hiddenStates, weight["transformer.ln_f.weight"], 1e-6, hiddenStates);
        Data logits;
        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);

        std::vector<int> lastRet;
        int total = 0;
        Data curLogit;
        for (int b = 0; b < batch; b++) {
            // Take the logits of the last position of request b.
            Split(logits, 1, total + seqLens[b] - 1, total + seqLens[b], curLogit);
            if (generationConfigs[b].output_logits && retLogits != nullptr && (*retLogits)[b] != nullptr) {
                curLogit.ToDevice(DataDevice::CPU);
                (*retLogits)[b]->resize(curLogit.Count(0));
                memcpy((float *) (*retLogits)[b]->data(), (float *) curLogit.cpuData, curLogit.GetBytes());
            }
            if (generationConfigs[b].IsSimpleGreedy()) {
                Data topk;
                TopK(curLogit, topk, 1);
                topk.ToDevice(DataDevice::CPU);
                lastRet.push_back((int) (((float *) topk.cpuData)[0] + 1e-3));
            } else {
                lastRet.push_back(LLMSampling(curLogit, 0, generationConfigs[b], lastTokens.units[b]));
            }
            total += seqLens[b];
        }
        return lastRet;
    }
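    // Prompt construction. With chat_format == "chatml" the conversation is
    // rendered in Qwen's ChatML template (im_start / im_end come from qwen.h;
    // assuming the usual "<|im_start|>" / "<|im_end|>" markers, round 0 renders
    // roughly as):
    //   <|im_start|>system
    //   You are a helpful assistant.<|im_end|>
    //   <|im_start|>user
    //   {input}<|im_end|>
    //   <|im_start|>assistant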
    std::string QWenModel::MakeInput(const std::string &history, int round, const std::string &input) {
        if (weight.dicts["chat_format"] == "chatml") {
            return (round == 0 ? im_start + "system" + "\n" + pre_prompt + im_end : history) +
                   "\n" + im_start + user_role + "\n" + input + im_end +
                   "\n" + im_start + bot_role + "\n";
        } else if (weight.dicts["chat_format"] == "raw") {
            return history + input;
        } else {
            ErrorInFastLLM("Unknown chat_format for QWen: " + weight.dicts["chat_format"]);
            return "";
        }
    }

    std::string QWenModel::MakeHistory(const std::string &history, int round,
                                       const std::string &input, const std::string &output) {
        if (weight.dicts["chat_format"] == "chatml") {
            return (round == 0 ? im_start + "system" + "\n" + pre_prompt + im_end : history) +
                   "\n" + im_start + user_role + "\n" + input + im_end +
                   "\n" + im_start + bot_role + "\n" + output + im_end;
        } else if (weight.dicts["chat_format"] == "raw") {
            return history + input + output;
        } else {
            ErrorInFastLLM("Unknown chat_format for QWen: " + weight.dicts["chat_format"]);
            return "";
        }
    }

    void QWenModel::FillLLMInputs(std::vector<std::vector<float>> &inputTokens,
                                  const std::map<std::string, int> &params,
                                  Data &inputIds, Data &attentionMask, Data &positionIds) {
        int index = params.find("index")->second;
        int promptLen = params.find("promptLen")->second;

        inputIds.ToDevice(DataDevice::CPU);
        attentionMask.ToDevice(DataDevice::CPU);
        positionIds.ToDevice(DataDevice::CPU);

        if (index == 0) {
            // Prefill: build a causal mask (1 marks positions to mask out)
            // and sequential position ids.
            int seqLen = inputTokens[0].size();
            std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0);
            std::vector<float> vpids = std::vector<float>(seqLen, 0);
            for (int i = 0; i < seqLen; i++) {
                vpids[i] = i;
                for (int j = i + 1; j < seqLen; j++) {
                    vmask[i * seqLen + j] = 1;
                }
            }
            inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
            positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, vpids));
        } else {
            // Decode: the single new token attends to the whole cache, so no mask is needed.
            inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
            attentionMask.CopyFrom(Data());
            positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) (promptLen + index - 1)}));
        }
    }

    void QWenModel::FillLLMInputsBatch(std::vector<std::vector<float>> &inputTokens,
                                       const std::vector<std::map<std::string, int>> &params,
                                       Data &inputIds, Data &attentionMask, Data &positionIds) {
        int batch = inputTokens.size();
        int index = params[0].find("index")->second;
        int promptLen = params[0].find("promptLen")->second;

        inputIds.ToDevice(DataDevice::CPU);
        attentionMask.ToDevice(DataDevice::CPU);
        positionIds.ToDevice(DataDevice::CPU);

        if (index == 0) {
            std::vector<int> seqLens;
            seqLens.resize(batch);
            int maxLen = 0;
            for (int i = 0; i < batch; i++) {
                maxLen = std::max(maxLen, (int) inputTokens[i].size());
                seqLens[i] = (int) inputTokens[i].size();
            }

            // Left-pad every request to maxLen with token 151643 (Qwen's <|endoftext|>),
            // masking out both the padding and future positions.
            int seqLen = maxLen;
            std::vector<float> ids = std::vector<float>(batch * seqLen, 151643);
            std::vector<float> vpids = std::vector<float>(batch * seqLen, 0);
            std::vector<float> vmask = std::vector<float>(batch * seqLen * seqLen, 0);
            for (int b = 0; b < batch; b++) {
                auto &tokens = inputTokens[b];
                int len = tokens.size(), base = maxLen - len;
                for (int i = 0; i < len; i++) {
                    ids[b * seqLen + base + i] = inputTokens[b][i];
                    vpids[b * seqLen + base + i] = i;
                }
                std::fill(vmask.data() + b * maxLen * maxLen,
                          vmask.data() + b * maxLen * maxLen + (maxLen - len) * maxLen, 1.0);
                for (int j = maxLen - len; j < maxLen; j++) {
                    std::fill(vmask.data() + b * maxLen * maxLen + j * maxLen,
                              vmask.data() + b * maxLen * maxLen + j * maxLen + maxLen - len, 1.0);
                }
                for (int j = 0; j < len; j++) {
                    for (int k = j + 1; k < len; k++) {
                        vmask[b * maxLen * maxLen + (base + j) * maxLen + base + k] = 1;
                    }
                }
            }
            inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen}, ids));
            attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen, seqLen}, vmask));
            positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen}, vpids));
        } else {
            std::vector<float> ids = std::vector<float>(batch * 1, 0);
            std::vector<float> vpids = std::vector<float>(batch * 1, 0);
            for (int b = 0; b < batch; b++) {
                ids[b] = inputTokens[b][0];
                vpids[b] = (float) (promptLen + index - 1);
            }
            inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, ids));
            attentionMask.CopyFrom(Data());
            positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, vpids));
        }
    }
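    // Warm-up: run one single-token forward pass through every block so that
    // weights are transferred and buffers allocated before the first request.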
    void QWenModel::WarmUp() {
        printf("Warmup...\n");
        Data inputIds = Data(DataType::FLOAT32, {1, 1}, {1});
        Data attentionMask = Data(DataType::FLOAT32, {1, 1}, {0});
        Data positionIds = Data(DataType::FLOAT32, {1, 1}, {0});

        std::vector<std::pair<Data, Data>> pastKeyValues;
        for (int i = 0; i < block_cnt; i++) {
            pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), Data(DataType::FLOAT32)));
        }
        Forward(inputIds, attentionMask, positionIds, pastKeyValues);
#ifdef USE_TFACC40T
        FastllmTfaccReleaseTempMemory();
#endif
        printf("finish.\n");
    }

    void QWenModel::UpdateRotaryPosEmb(float ntk_alpha) {
        // NTK-aware scaling: enlarge the RoPE base by ntk_alpha^(dim / (dim - 2)).
        float base = 10000 * pow(ntk_alpha, (float) rotary_dim / (rotary_dim - 2));
        if (sin.empty() || cos.empty()) {
            sin.resize(max_positions);
            cos.resize(max_positions);
        }

        std::vector<float> invFreq;
        for (int i = 0; i < rotary_dim; i += 2) {
            invFreq.push_back(1.0 / pow(base, (float) i / rotary_dim));
        }
        for (int i = 0; i < max_positions; i++) {
            sin[i].resize(rotary_dim);
            cos[i].resize(rotary_dim);
            for (int j = 0; j < invFreq.size(); j++) {
                sin[i][j] = ::sin((float) i * invFreq[j]);
                cos[i][j] = ::cos((float) i * invFreq[j]);
            }
        }

        std::vector<float> fsin, fcos;
        for (int i = 0; i < sin.size(); i++) {
            for (int j = 0; j < sin[0].size(); j++) {
                fsin.push_back(sin[i][j]);
                fcos.push_back(cos[i][j]);
            }
        }
        sinData.ToDevice(DataDevice::CPU);
        cosData.ToDevice(DataDevice::CPU);
        sinData.CopyFrom(Data(DataType::FLOAT32, {(int) this->sin.size(), (int) this->sin[0].size()}, fsin));
        cosData.CopyFrom(Data(DataType::FLOAT32, {(int) this->cos.size(), (int) this->cos[0].size()}, fcos));
    }
}