Unverified Commit 80909bee authored by pengcheng888, committed by GitHub

Issue/60 -main: fix garbled output tokens and add support for the Qwen3 model

parent 753a4f60
......@@ -35,6 +35,10 @@ typedef struct
const void *const *attn_qkv;
// nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
const void *const *attn_qkv_b;
// nlayer * [dh]
const void *const *attn_q_norm;
// nlayer * [dh]
const void *const *attn_k_norm;
// nlayer * [ndev, d, nkvh / ndev * dh]
const void *const *attn_o;
// nlayer * [d]
......
......@@ -662,11 +662,7 @@ class DeepSeekV3ForCauslLM:
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
-output_str = (
-    self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-    .replace("▁", " ")
-    .replace("<0x0A>", "\n")
-)
+output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
......
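This same replacement recurs throughout the commit, and it is the actual fix for the garbled output: `id_to_token` returns the raw surface form of a single piece, so SentencePiece byte-fallback pieces like `<0xE4>` reach the terminal as literal text and multi-byte UTF-8 characters are never reassembled. `tokenizer.decode` runs the full decoder pipeline instead. A minimal sketch of the difference, assuming a Hugging Face tokenizer with byte fallback (model path and printed ids are illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./model_dir")  # hypothetical local path

# A CJK character is stored as three byte-fallback pieces.
ids = tok.encode("你", add_special_tokens=False)

# Per-piece surface forms: literal byte markers, i.e. garbage on screen.
print(tok.convert_ids_to_tokens(ids))  # e.g. ['<0xE4>', '<0xBD>', '<0xA0>']

# The decoder pipeline fuses the bytes back into valid UTF-8.
print(tok.decode(ids))  # 你
```

Note that decoding one id at a time, as these streaming loops do, still cannot reassemble a character whose bytes span several steps; robust streamers typically buffer ids until the decoded text ends without U+FFFD.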
......@@ -58,6 +58,12 @@ class LlamaWeightsNaming:
def attn_v_b(self, i):
return f"model.layers.{i}.self_attn.v_proj.bias"
def attn_q_norm(self, i):
return f"model.layers.{i}.self_attn.q_norm.weight"
def attn_k_norm(self, i):
return f"model.layers.{i}.self_attn.k_norm.weight"
def ffn_norm(self, i):
return f"model.layers.{i}.post_attention_layernorm.weight"
......@@ -117,7 +123,7 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
if "num_key_value_heads" in config
else config["num_attention_heads"]
),
-dh=config["hidden_size"] // config["num_attention_heads"],
+dh=config["head_dim"] if "head_dim" in config else config["hidden_size"] // config["num_attention_heads"],
di=config["intermediate_size"],
dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
......@@ -275,6 +281,35 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
else:
self.attn_qkv_b = None
if naming.attn_q_norm(0) in state_dict:
self.attn_q_norm_tensors = [
state_dict[naming.attn_q_norm(i)]
.reshape([2, dh // 2])
.transpose(0, 1)
.contiguous()
.to(torch_dt_norm)
for i in range(nlayer)
]
self.attn_q_norm_ptrs = [
self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
]
self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
self.attn_k_norm_tensors = [
state_dict[naming.attn_k_norm(i)]
.reshape([2, dh // 2])
.transpose(0, 1)
.contiguous()
.to(torch_dt_norm)
for i in range(nlayer)
]
self.attn_k_norm_ptrs = [
self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
]
self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
else:
self.attn_q_norm = None
self.attn_k_norm = None
self.attn_o_tensor = [
(
state_dict[naming.attn_o(i)]
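The `reshape([2, dh // 2]).transpose(0, 1)` on the q/k norm weights is a layout permutation, not a reshape for its own sake: the checkpoint stores the [dh] weight in rope half-split order, while this engine (judging by the analogous permutation this loader applies to the qkv projection weights) expects the two halves interleaved pairwise. A small sketch of what the permutation does, with dh = 8:

```python
import torch

dh = 8
w = torch.arange(dh)  # checkpoint order: [0 1 2 3 | 4 5 6 7], two rope halves
w_interleaved = w.reshape(2, dh // 2).transpose(0, 1).contiguous().view(-1)
print(w_interleaved.tolist())  # [0, 4, 1, 5, 2, 6, 3, 7]
```

The `.contiguous()` matters because the C side receives a bare `data_ptr()` and assumes dense memory.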
......@@ -481,7 +516,7 @@ class JiugeForCauslLM:
)
else:
raise ValueError("Unsupported weight naming")
elif "qwen2" == config["model_type"]:
elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]:
state_dict = load_all_safetensors_from_dir(model_dir_path)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
......@@ -498,6 +533,24 @@ class JiugeForCauslLM:
else:
raise ValueError("Unsupported model architecture")
if "llama" == config["model_type"]:
from tokenizers import decoders as _dec
backend = getattr(self.tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend)
norm = getattr(target, "normalizer", None)
dec = getattr(target, "decoder", None)
sn = repr(norm)[:800] if norm is not None else ""
sd = repr(dec)[:800] if dec is not None else ""
has_prepend = "Prepend" in sn
has_strip = "Strip" in sd
if has_prepend and has_strip:
target.decoder = _dec.Sequence([
_dec.Replace("▁", " "),
_dec.ByteFallback(),
_dec.Fuse(),
])
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
......@@ -574,11 +627,8 @@ class JiugeForCauslLM:
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
-output_str = (
-    self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-    .replace("▁", " ")
-    .replace("<0x0A>", "\n")
-)
+output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
......
......@@ -256,11 +256,6 @@ class JiugeAWQForCausalLM:
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
-# output_str = (
-#     self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-#     .replace("▁", " ")
-#     .replace("<0x0A>", "\n")
-# )
output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str
print(output_str, end="", flush=True)
......
......@@ -226,11 +226,8 @@ async def chat_stream(id_, request_data, request: Request):
break
token = await infer_task.output_queue.async_q.get()
-content = (
-    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-    .replace("▁", " ")
-    .replace("<0x0A>", "\n")
-)
+content = request.app.state.model.tokenizer.decode(token)
chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
yield f"data: {chunk}\n\n"
......@@ -255,11 +252,7 @@ async def chat(id_, request_data, request: Request):
break
token = await infer_task.output_queue.async_q.get()
-content = (
-    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-    .replace("▁", " ")
-    .replace("<0x0A>", "\n")
-)
+content = request.app.state.model.tokenizer.decode(token)
output.append(content)
output_text = "".join(output).strip()
......
......@@ -31,6 +31,8 @@ class JiugeWeightsCStruct(Structure):
("attn_norm", POINTER(c_void_p)),
("attn_qkv", POINTER(c_void_p)),
("attn_qkv_b", POINTER(c_void_p)),
("attn_q_norm", POINTER(c_void_p)),
("attn_k_norm", POINTER(c_void_p)),
("attn_o", POINTER(c_void_p)),
("ffn_norm", POINTER(c_void_p)),
("ffn_gate_up", POINTER(c_void_p)),
......
......@@ -43,11 +43,7 @@ class JiugeForCeval(JiugeForCauslLM):
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
-output_str = (
-    self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-    .replace("▁", " ")
-    .replace("<0x0A>", "\n")
-)
+output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
......
......@@ -21,7 +21,7 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
infinirtStream_t stream;
infinirtStreamCreate(&stream);
-std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
w_ffn_norm, w_ffn_gate_up, w_ffn_down;
for (size_t layer = 0; layer < meta->nlayer; layer++) {
w_attn_norm.push_back(
......@@ -32,6 +32,13 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
b_attn_qkv.push_back(
getAttnQKVBias(meta, weights, layer, idev, ndev));
}
if (weights->attn_q_norm != nullptr) {
w_attn_q_norm.push_back(
getAttnQNorm(meta, weights, layer));
w_attn_k_norm.push_back(
getAttnKNorm(meta, weights, layer));
}
w_attn_out.push_back(
getAttnO(meta, weights, layer, idev, ndev));
w_ffn_norm.push_back(
......@@ -56,6 +63,8 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
w_attn_norm,
w_attn_qkv,
b_attn_qkv,
w_attn_q_norm,
w_attn_k_norm,
w_attn_out,
w_ffn_norm,
w_ffn_gate_up,
......@@ -130,6 +139,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
auto dvoc = meta.dvoc;
auto stream = rsrc.stream;
bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0;
// Allocate buffers
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
......@@ -142,6 +152,8 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
auto q_buf = qkv_rope->slice(1, 0, nh);
auto k_buf = qkv_rope->slice(1, nh, nkvh);
// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
......@@ -198,9 +210,15 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
if (has_qk_norm) {
rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon);
rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon);
}
// rope
-rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
-rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
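The new `has_qk_norm` path implements Qwen3's per-head query/key RMSNorm, applied after the fused qkv projection and before rope; the [dh] weight is shared across all heads, which is why one tensor per layer suffices even under GQA. A PyTorch reference of the ordering the kernel code above follows (shapes hypothetical):

```python
import torch

ntok, nh, nkvh, dh, eps = 4, 8, 2, 16, 1e-6  # hypothetical sizes

def rmsnorm(x, w):
    # Normalize each head vector over dh, then scale by the shared weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

qkv = torch.randn(ntok, nh + 2 * nkvh, dh)  # fused projection output
q = qkv[:, :nh]                             # ~ qkv_rope->slice(1, 0, nh)
k = qkv[:, nh:nh + nkvh]                    # ~ qkv_rope->slice(1, nh, nkvh)

q = rmsnorm(q, torch.randn(dh))  # w_attn_q_norm[layer]
k = rmsnorm(k, torch.randn(dh))  # w_attn_k_norm[layer]
# rope would be applied to q and k next, matching the kernel order above.
```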
......@@ -299,11 +317,11 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
__C void
inferBatchJiuge(struct JiugeModel *model,
-        const uint32_t *tokens, uint32_t ntok,
-        const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
-        struct KVCache **kv_caches,
-        const float *temperature, const uint32_t *topk, const float *topp,
-        uint32_t *output) {
+                const uint32_t *tokens, uint32_t ntok,
+                const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                struct KVCache **kv_caches,
+                const float *temperature, const uint32_t *topk, const float *topp,
+                uint32_t *output) {
model->req.tokens = tokens;
model->req.ntok = ntok;
model->req.req_lens = req_lens;
......@@ -332,10 +350,10 @@ inferBatchJiuge(struct JiugeModel *model,
__C void
forwardBatchJiuge(struct JiugeModel *model,
-        const uint32_t *tokens, uint32_t ntok,
-        const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
-        struct KVCache **kv_caches,
-        void *logits) {
+                  const uint32_t *tokens, uint32_t ntok,
+                  const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                  struct KVCache **kv_caches,
+                  void *logits) {
model->req.tokens = tokens;
model->req.ntok = ntok;
model->req.req_lens = req_lens;
......
......@@ -20,7 +20,7 @@ struct JiugeDeviceResource {
// Weights
std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
cos_table;
-std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
w_ffn_norm, w_ffn_gate_up, w_ffn_down;
// Streams
infinirtStream_t stream;
......
......@@ -70,6 +70,22 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
}
inline std::shared_ptr<Tensor> getAttnQNorm(
JiugeMeta const *meta,
JiugeWeights const *w,
size_t layer) {
auto shape = std::vector<size_t>({meta->dh});
return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape);
}
inline std::shared_ptr<Tensor> getAttnKNorm(
JiugeMeta const *meta,
JiugeWeights const *w,
size_t layer) {
auto shape = std::vector<size_t>({meta->dh});
return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape);
}
inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
JiugeWeights const *w, size_t layer,
size_t idev, size_t ndev) {
......