Unverified Commit 3e3c2743 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #3 from InfiniTensor/kvcache_pool

支持多请求连续batch以及kvcache池
parents a73433ab 2b9ce5a6
...@@ -75,22 +75,6 @@ __C __export void ...@@ -75,22 +75,6 @@ __C __export void
dropKVCache(const struct JiugeModel *, dropKVCache(const struct JiugeModel *,
struct KVCache *); struct KVCache *);
/// @brief 文本生成
/// @param tokens 输入 token
/// @param ntok 输入 token 数量
/// @param req_pos 每个请求的起始位置
/// @param output 输出 token 地址
/// @param max_step 输出 token 最大数量
/// @param temperature 采样温度(0. 表示贪心采样)
/// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp
__C __export void
generate(struct JiugeModel *,
struct KVCache *,
const uint32_t *tokens, uint32_t ntok, uint32_t req_pos,
uint32_t *output, uint32_t max_step,
float temperature, uint32_t topk, float topp);
/// @brief 批次推理一轮 /// @brief 批次推理一轮
/// @param tokens 输入 token 地址 /// @param tokens 输入 token 地址
/// @param ntok 输入 token 数量 /// @param ntok 输入 token 数量
...@@ -98,16 +82,16 @@ generate(struct JiugeModel *, ...@@ -98,16 +82,16 @@ generate(struct JiugeModel *,
/// @param req_lens 每个请求的 token 数量 /// @param req_lens 每个请求的 token 数量
/// @param req_pos 每个请求的起始位置 /// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache /// @param kv_caches 每个请求的 KV Cache
/// @param ans 输出 token 数组,每个请求一个输出,长度至少为nreq
/// @param temperature 采样温度(0. 表示贪心采样) /// @param temperature 采样温度(0. 表示贪心采样)
/// @param topk 采样 topk(1 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp /// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __C __export void
inferBatch(struct JiugeModel *, inferBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches, struct KVCache **kv_caches,
uint32_t *output, const float *temperature, const uint32_t *topk, const float *topp,
float temperature, uint32_t topk, float topp); uint32_t *output);
#endif #endif
class InferTask:
    """State for one generation request as it moves through batched decoding.

    Tracks the tokens still to be fed to the model, the absolute position in
    the KV cache, the sampling parameters, and the reason generation stopped.
    """

    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
        self.id = id
        # Set to "stop" / "length" (or "cancel" by callers) once generation ends.
        self.finish_reason = None
        self.tokens = tokens
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.topk = topk
        self.topp = topp
        # Token ids that terminate generation (e.g. EOS ids).
        self.end_tokens = end_tokens
        self._kv_cache = None
        self.pos = 0

    def bind_kvcache(self, kv_cache, pos=0):
        """Attach a KV cache that already covers the first `pos` tokens.

        The prefix already held by the cache is skipped so only the uncached
        suffix of `tokens` is fed to the model.
        """
        self._kv_cache = kv_cache
        self.pos = pos
        self.tokens = self.tokens[pos:]

    def release_kvcache(self):
        """Detach and return the bound KV cache (e.g. to hand it back to a pool)."""
        cache = self._kv_cache
        self._kv_cache = None
        return cache

    def kvcache(self):
        """Return the currently bound KV cache, or None if unbound."""
        return self._kv_cache

    def next(self, out_token):
        """Advance the task after one decoding round produced `out_token`."""
        # Record the tokens just consumed into the cache's token history.
        self._kv_cache.update_tokens(self.tokens, self.pos)
        self.pos += len(self.tokens)
        # Fix: compare to None with `is`, not `==` (PEP 8 identity idiom).
        if out_token is None or out_token in self.end_tokens:
            self.finish_reason = "stop"
        elif self.pos >= self.max_tokens:
            self.finish_reason = "length"
        else:
            # Next round feeds only the newly generated token.
            self.tokens = [out_token]
class KVCache:
    """Host-side wrapper pairing a model KV cache handle with its token history.

    `tokens` mirrors the token ids currently materialized in the cache so a
    pool can later measure how much of a new prompt a cache already covers.
    """

    def __init__(self, model):
        # Allocate the device-side cache via the model's C API wrapper.
        self._kvcache = model.create_kv_cache()
        # One slot per context position, zero-filled until written.
        self.tokens = [0] * model.max_context_len()

    def data(self):
        """Return the raw cache handle to pass to the C interface."""
        return self._kvcache

    def drop(self, model):
        """Free the underlying cache through the owning model."""
        model.drop_kv_cache(self._kvcache)

    def update_tokens(self, tokens, pos):
        """Record `tokens` into the history starting at `pos`.

        Writes past the context window are clipped rather than raising.
        """
        limit = len(self.tokens)
        stop = pos + len(tokens)
        if stop > limit:
            # Keep only the part that fits inside the context window.
            tokens = tokens[: limit - pos]
            stop = limit
        self.tokens[pos:stop] = tokens
from ctypes import POINTER, c_int, c_uint, c_void_p, byref from typing import List
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import asyncio
from libinfinicore_infer import ( from libinfinicore_infer import (
JiugeMeta, JiugeMetaCStruct,
JiugeWeights, JiugeWeightsCStruct,
KVCache, KVCacheCStruct,
DataType, DataType,
DeviceType, DeviceType,
create_jiuge_model, create_jiuge_model,
...@@ -19,11 +11,21 @@ from libinfinicore_infer import ( ...@@ -19,11 +11,21 @@ from libinfinicore_infer import (
drop_kv_cache, drop_kv_cache,
infer_batch, infer_batch,
) )
from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import torch import torch
import transformers import transformers
torch.set_default_device("cpu") torch.set_default_device("cpu")
class LlamaWeightsNaming: class LlamaWeightsNaming:
def input_embd(self): def input_embd(self):
return "model.embed_tokens.weight" return "model.embed_tokens.weight"
...@@ -77,7 +79,7 @@ class LlamaWeightsNaming: ...@@ -77,7 +79,7 @@ class LlamaWeightsNaming:
) )
class JiugeMetaFromLlama(JiugeMeta): class JiugeMetaFromLlama(JiugeMetaCStruct):
def __init__(self, config, dtype=torch.float16): def __init__(self, config, dtype=torch.float16):
if dtype == torch.float16: if dtype == torch.float16:
dt_ = DataType.INFINI_DTYPE_F16 dt_ = DataType.INFINI_DTYPE_F16
...@@ -106,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta): ...@@ -106,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
self.torch_dtype_logits = dtype self.torch_dtype_logits = dtype
class JiugeWeightsImpl(JiugeWeights): class JiugeWeightsImpl(JiugeWeightsCStruct):
def __init__( def __init__(
self, self,
meta, meta,
...@@ -159,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -159,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
self.output_norm = self.output_norm_tensor.data_ptr() self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight: if not transpose_weight:
self.output_embd_tensor = self.output_embd_tensor.transpose(0, 1).contiguous() self.output_embd_tensor = self.output_embd_tensor.transpose(
0, 1
).contiguous()
self.output_embd = self.output_embd_tensor.data_ptr() self.output_embd = self.output_embd_tensor.data_ptr()
self.attn_norm_tensors = [ self.attn_norm_tensors = [
...@@ -196,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -196,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
] ]
if not transpose_weight: if not transpose_weight:
for i in range(nlayer): for i in range(nlayer):
self.qkv_tensor[i] = self.qkv_tensor[i].reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d).transpose(1, 2).contiguous() self.qkv_tensor[i] = (
self.qkv_tensor[i]
.reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d)
.transpose(1, 2)
.contiguous()
)
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)
...@@ -233,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -233,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_qkv_b = None self.attn_qkv_b = None
self.attn_o_tensor = [ self.attn_o_tensor = [
state_dict[naming.attn_o(i)] (
state_dict[naming.attn_o(i)]
.to(torch_dt_mat) .to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh]) .reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1) .transpose(0, 1)
.contiguous() .contiguous()
if transpose_weight if transpose_weight
else state_dict[naming.attn_o(i)].transpose(0, 1).to(torch_dt_mat).contiguous() else state_dict[naming.attn_o(i)]
.transpose(0, 1)
.to(torch_dt_mat)
.contiguous()
)
for i in range(nlayer) for i in range(nlayer)
] ]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -268,24 +282,75 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -268,24 +282,75 @@ class JiugeWeightsImpl(JiugeWeights):
] ]
if not transpose_weight: if not transpose_weight:
for i in range(nlayer): for i in range(nlayer):
self.gate_up_tensors[i] = self.gate_up_tensors[i].reshape(ndev, 2 * di // ndev, d).transpose(1, 2).contiguous() self.gate_up_tensors[i] = (
self.gate_up_tensors[i]
.reshape(ndev, 2 * di // ndev, d)
.transpose(1, 2)
.contiguous()
)
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
self.ffn_down_tensor = [ self.ffn_down_tensor = [
state_dict[naming.down(i)] (
.to(torch_dt_mat) state_dict[naming.down(i)]
.reshape([d, ndev, di // ndev]) .to(torch_dt_mat)
.transpose(0, 1) .reshape([d, ndev, di // ndev])
.contiguous() .transpose(0, 1)
if transpose_weight .contiguous()
else state_dict[naming.down(i)].transpose(0, 1).to(torch_dt_mat).contiguous() if transpose_weight
else state_dict[naming.down(i)]
.transpose(0, 1)
.to(torch_dt_mat)
.contiguous()
)
for i in range(nlayer) for i in range(nlayer)
] ]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs)
class JiugeBatchedTask:
    """Packs a batch of InferTask objects into the ctypes arrays that
    `infer_batch` expects, in argument order."""

    def __init__(self, tasks: List[InferTask]):
        self.tasks = tasks
        self.nreq = len(tasks)

        # Gather per-request metadata and flatten all input tokens in one pass.
        self.req_lens_list = []
        self.req_pos_list = []
        self.kv_cache_ptrs = []
        self.temperaturas_list = []
        self.topks_list = []
        self.topps_list = []
        flat_tokens = []
        for task in tasks:
            self.req_lens_list.append(len(task.tokens))
            self.req_pos_list.append(task.pos)
            self.kv_cache_ptrs.append(task.kvcache().data())
            self.temperaturas_list.append(task.temperature)
            self.topks_list.append(task.topk)
            self.topps_list.append(task.topp)
            flat_tokens.extend(task.tokens)
        self.ntok = len(flat_tokens)

        # Materialize the ctypes arrays handed to the C interface.
        self.tokens = (c_uint * self.ntok)(*flat_tokens)
        self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
        self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
        self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
        self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
        self.topks = (c_uint * self.nreq)(*self.topks_list)
        self.topps = (c_float * self.nreq)(*self.topps_list)

    def input_args(self):
        """Return the positional arguments for `infer_batch`, after the model
        handle and before the output buffer."""
        return (
            self.tokens,
            self.ntok,
            self.req_lens,
            self.nreq,
            self.req_pos,
            self.kv_caches,
            self.temperaturas,
            self.topks,
            self.topps,
        )
class JiugeForCauslLM: class JiugeForCauslLM:
def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1): def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1):
def load_all_safetensors_from_dir(dir_path_: str): def load_all_safetensors_from_dir(dir_path_: str):
...@@ -296,7 +361,7 @@ class JiugeForCauslLM: ...@@ -296,7 +361,7 @@ class JiugeForCauslLM:
for name_ in data_.keys(): for name_ in data_.keys():
tensors_[name_] = data_.get_tensor(name_) tensors_[name_] = data_.get_tensor(name_)
return tensors_ return tensors_
print("Loading model weights to host...") print("Loading model weights to host...")
load_start_time = time.time() load_start_time = time.time()
...@@ -304,26 +369,46 @@ class JiugeForCauslLM: ...@@ -304,26 +369,46 @@ class JiugeForCauslLM:
config = json.load(f) config = json.load(f)
self.config = config self.config = config
eos_token_id = self.config["eos_token_id"] eos_token_id = self.config["eos_token_id"]
self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id self.eos_token_id = (
transpose_weight = device != DeviceType.DEVICE_TYPE_ASCEND # y = xW is faster than y=xW^T on Ascend [eos_token_id] if type(eos_token_id) == int else eos_token_id
)
transpose_weight = (
device != DeviceType.DEVICE_TYPE_ASCEND
) # y = xW is faster than y=xW^T on Ascend
if "llama" == config["model_type"]: if "llama" == config["model_type"]:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half() model = (
transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
.cpu()
.half()
)
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
model.state_dict(),
ndev=ndev,
transpose_weight=transpose_weight,
) )
elif "fm9g" == config["model_type"]: elif "fm9g" == config["model_type"]:
if any(file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()): if any(
file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
):
state_dict = load_all_safetensors_from_dir(model_dir_path) state_dict = load_all_safetensors_from_dir(model_dir_path)
else: else:
state_dict = torch.load( state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu" os.path.join(model_dir_path, "pytorch_model.bin"),
) weights_only=True,
map_location="cpu",
)
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
) )
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True model_dir_path, trust_remote_code=True
...@@ -332,12 +417,18 @@ class JiugeForCauslLM: ...@@ -332,12 +417,18 @@ class JiugeForCauslLM:
raise ValueError("Unsupported weight naming") raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]: elif "fm9g7b" == config["model_type"]:
state_dict = torch.load( state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu" os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
) )
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
) )
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True model_dir_path, trust_remote_code=True
...@@ -349,7 +440,11 @@ class JiugeForCauslLM: ...@@ -349,7 +440,11 @@ class JiugeForCauslLM:
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
) )
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path model_dir_path
...@@ -359,7 +454,7 @@ class JiugeForCauslLM: ...@@ -359,7 +454,7 @@ class JiugeForCauslLM:
load_end_time = time.time() load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s") print(f"Time used: {load_end_time - load_start_time:.3f}s")
print(f"Creating model on {ndev} devices...") print(f"Creating model on {ndev} devices...")
load_start_time = time.time() load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
...@@ -372,126 +467,27 @@ class JiugeForCauslLM: ...@@ -372,126 +467,27 @@ class JiugeForCauslLM:
) )
load_end_time = time.time() load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s") print(f"Time used: {load_end_time - load_start_time:.3f}s")
def max_context_len(self):
return self.meta.dctx
def create_kv_cache(self): def create_kv_cache(self):
return create_kv_cache(self.model_instance) return create_kv_cache(self.model_instance)
def drop_kv_cache(self, kv_cache): def drop_kv_cache(self, kv_cache):
drop_kv_cache(self.model_instance, kv_cache) drop_kv_cache(self.model_instance, kv_cache)
def chat(self, request, kv_cache): def batch_infer_one_round(self, tasks: List[InferTask]):
messages = request.get("messages", []) output = (c_uint * len(tasks))()
temperature = request.get("temperature", 1.0) batch_inputs = JiugeBatchedTask(tasks)
topk = request.get("top_k", 1) infer_batch(
topp = request.get("top_p", 1.0) self.model_instance,
max_tokens = request.get("max_tokens", self.meta.dctx) *(batch_inputs.input_args()),
input_content = self.tokenizer.apply_chat_template( output,
conversation=messages,
add_generation_prompt=True,
tokenize=False,
) )
return list(output)
tokens = self.tokenizer.encode(input_content) def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
ntok = len(tokens)
nreq = 1
output_content = ""
tokens = (c_uint * ntok)(*tokens)
req_lens = (c_uint * nreq)(*[ntok])
req_pos = (c_uint * nreq)(*[0])
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
ans = (c_uint * nreq)()
steps = 0
for step_i in range(max_tokens):
infer_batch(
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
ans,
temperature,
topk,
topp,
)
steps += 1
output_tokens = list(ans)
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
output_content += output_str
if output_tokens[0] in self.eos_token_id:
break
req_pos[0] = req_pos[0] + ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
return output_content
async def chat_stream_async(self, request, kv_cache):
messages = request.get("messages", [])
temperature = request.get("temperature", 1.0)
topk = request.get("top_k", 1)
topp = request.get("top_p", 1.0)
max_tokens = request.get("max_tokens", 512)
input_content = self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False,
)
tokens = self.tokenizer.encode(input_content)
ntok = len(tokens)
nreq = 1
tokens = (c_uint * ntok)(*tokens)
req_lens = (c_uint * nreq)(*[ntok])
req_pos = (c_uint * nreq)(*[0])
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
ans = (c_uint * nreq)()
for step_i in range(max_tokens):
infer_batch(
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
ans,
temperature,
topk,
topp,
)
output_tokens = list(ans)
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
yield output_str # Yield each token as it's produced
await asyncio.sleep(0) # Let event loop breathe
if output_tokens[0] in self.eos_token_id:
break
req_pos[0] += ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
kv_cache = create_kv_cache(self.model_instance)
input_content = self.tokenizer.apply_chat_template( input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}], conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True, add_generation_prompt=True,
...@@ -499,36 +495,26 @@ class JiugeForCauslLM: ...@@ -499,36 +495,26 @@ class JiugeForCauslLM:
) )
print(input_content, end="", flush=True) print(input_content, end="", flush=True)
tokens = self.tokenizer.encode(input_content) tokens = self.tokenizer.encode(input_content)
ntok = len(tokens) infer_task = InferTask(
nreq = 1 0,
output_content = "" tokens,
tokens = (c_uint * ntok)(*tokens) self.max_context_len(),
req_lens = (c_uint * nreq)(*[ntok]) temperature_,
req_pos = (c_uint * nreq)(*[0]) topk_,
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache]) topp_,
ans = (c_uint * nreq)() self.eos_token_id,
)
infer_task.bind_kvcache(KVCache(self))
steps = 0 steps = 0
total_time = 0 total_time = 0
output_content = ""
for step_i in range(max_steps): for step_i in range(max_steps):
start_time = time.time() start_time = time.time()
infer_batch( output_tokens = self.batch_infer_one_round([infer_task])
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
ans,
temperature,
topk,
topp,
)
steps += 1
output_tokens = list(ans)
end_time = time.time() end_time = time.time()
steps += 1
output_str = ( output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0]) self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ") .replace("▁", " ")
...@@ -538,21 +524,18 @@ class JiugeForCauslLM: ...@@ -538,21 +524,18 @@ class JiugeForCauslLM:
print(output_str, end="", flush=True) print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id: if output_tokens[0] in self.eos_token_id:
break break
req_pos[0] = req_pos[0] + ntok infer_task.next(output_tokens[0])
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
if step_i > 0: if step_i > 0:
total_time += end_time - start_time total_time += end_time - start_time
print("\n") print("\n")
avg_time = total_time * 1000 / (steps - 1) avg_time = total_time * 1000 / (steps - 1)
print(f"Time per step: {avg_time:.3f}ms") print(f"Time per step: {avg_time:.3f}ms")
for kv_cache in kv_caches:
drop_kv_cache(self.model_instance, kv_cache) infer_task._kv_cache.drop(self)
return output_content, avg_time return output_content, avg_time
def destroy_model_instance(self): def destroy_model_instance(self):
destroy_jiuge_model(self.model_instance) destroy_jiuge_model(self.model_instance)
print("Model destroyed") print("Model destroyed")
......
from infer_task import KVCache
import asyncio
from typing import List
import threading
class KVCachePool:
    """Thread-safe pool of reusable KVCache objects shared across requests.

    Caches are lazily created up to `max_caches`; when none are free,
    acquirers block on a condition variable until one is returned.
    Reuse prefers the cache whose recorded token history shares the longest
    prefix with the new request, so that prefix can be skipped at prefill.
    """

    def __init__(self, model, max_caches: int = 32):
        self.max_caches = max_caches
        self.model = model
        # Caches not currently bound to any task.
        self._available: List[KVCache] = []
        # Total caches ever created (available + in use).
        self.num_caches = len(self._available)
        self._lock = threading.Lock()
        # Signalled whenever a cache is returned (and on shutdown).
        self._not_empty = threading.Condition(self._lock)
        self._shutdown = False

    def acquire_sync(self, infer_task):
        """Blocking acquire: bind a cache to `infer_task`, creating one if
        the pool is under capacity, otherwise waiting for a release.

        Raises RuntimeError if the pool is shutting down.
        Note: returns the result of bind_kvcache (which returns None).
        """
        with self._not_empty:
            while True:
                if self._shutdown:
                    raise RuntimeError(
                        "KVCachePool is shutting down; cannot acquire new cache."
                    )
                if len(self._available) == 0:
                    if self.num_caches < self.max_caches:
                        # Under capacity: allocate a brand-new cache.
                        self.num_caches += 1
                        print(
                            f"[INFO] Task {infer_task.id} created new KVCachePoolItem"
                        )
                        return infer_task.bind_kvcache(KVCache(self.model), 0)
                    else:
                        # At capacity: wait for release_sync to notify.
                        self._not_empty.wait()
                else:
                    # Reuse the free cache with the longest shared token prefix.
                    max_match, max_match_index = self.find_most_matching_cache(
                        infer_task.tokens
                    )
                    kvcache = self._available.pop(max_match_index)
                    print(
                        f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches"
                    )
                    return infer_task.bind_kvcache(kvcache, max_match)

    def release_sync(self, infer_task):
        """Return the task's cache to the pool and wake one waiting acquirer."""
        with self._not_empty:
            print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
            self._available.append(infer_task.release_kvcache())
            self._not_empty.notify()

    async def acquire(self, infer_task):
        # Run the blocking acquire in the default executor so the event
        # loop is not blocked while waiting for a free cache.
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.acquire_sync, infer_task)

    async def release(self, infer_task):
        # Async wrapper over release_sync; see acquire() for rationale.
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.release_sync, infer_task)

    def find_most_matching_cache(self, tokens: List[int]):
        """Return (match_len, index) of the available cache sharing the
        longest common token prefix with `tokens`.

        Must be called with the lock held (callers inside `with self._not_empty`).
        """

        max_match = 0
        max_match_index = 0

        def first_different_index(a_, b_):
            # Length of the common prefix of two sequences.
            for i_, (x_, y_) in enumerate(zip(a_, b_)):
                if x_ != y_:
                    return i_
            return min(len(a_), len(b_))

        for i, kvcache in enumerate(self._available):
            common_elements = first_different_index(tokens, kvcache.tokens)
            # print(f"{tokens}")
            # print(f"{kvcache.tokens[:len(tokens)]}")
            if common_elements > max_match:
                max_match = common_elements
                max_match_index = i
        # Cap at len(tokens) - 1 — presumably so at least one token is still
        # fed to the model and an output is produced; TODO confirm.
        return (min(max_match, len(tokens) - 1), max_match_index)

    def finalize(self):
        """Shut down: refuse new acquires, wait until every outstanding cache
        is returned, then free them all through the model."""
        with self._not_empty:
            self._shutdown = True
            # Wait for all in-use caches to come back before freeing.
            while len(self._available) < self.num_caches:
                self._not_empty.wait()
            for kvcache in self._available:
                if kvcache is not None:
                    kvcache.drop(self.model)
            self._available.clear()
            self.max_caches = 0
            self.num_caches = 0
from jiuge import JiugeForCauslLM from jiuge import JiugeForCauslLM
from libinfinicore_infer import DeviceType from libinfinicore_infer import DeviceType
from infer_task import InferTask
from kvcache_pool import KVCachePool
import queue
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
import anyio import contextlib
import uvicorn import uvicorn
import time import time
import uuid import uuid
import sys import sys
import signal
import json import json
import threading
import janus
if len(sys.argv) < 3: if len(sys.argv) < 3:
print( print(
...@@ -37,26 +41,6 @@ else: ...@@ -37,26 +41,6 @@ else:
sys.exit(1) sys.exit(1)
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = JiugeForCauslLM(model_path, device_type, ndev)
kv_cache = model.create_kv_cache()
def signal_handler(sig, frame):
print(f"Received signal {sig}, cleaning up...")
model.drop_kv_cache(kv_cache)
model.destroy_model_instance()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle docker stop / system shutdown
app = FastAPI()
# TO REMOVE: Global lock to ensure only one request is handled at a time
# Remove this after multiple requests handling is implemented
request_lock = anyio.Lock()
def chunk_json(id_, content=None, role=None, finish_reason=None): def chunk_json(id_, content=None, role=None, finish_reason=None):
delta = {} delta = {}
...@@ -81,46 +65,178 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): ...@@ -81,46 +65,178 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
} }
# Maximum number of requests combined into one inference round by worker_loop.
MAX_BATCH = 3
print(
    f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
)
# A wrapper for InferTask that supports async output queue
class AsyncInferTask(InferTask):
    """InferTask that publishes each generated token to a janus queue so the
    async HTTP handlers can consume output produced by the worker thread."""

    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
        super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens)
        # janus.Queue bridges the sync worker thread and async consumers.
        self.output_queue = janus.Queue()
        print(f"[INFO] Create InferTask {self.id}")

    def output(self, out_token):
        # Advance task state first, then hand the token to async consumers.
        self.next(out_token)
        self.output_queue.sync_q.put(out_token)
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: build the model, KV cache pool, and request queue on
    startup; drain and destroy them on shutdown."""
    # Startup
    app.state.model = JiugeForCauslLM(model_path, device_type, ndev)
    app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
    app.state.request_queue = janus.Queue()
    worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
    worker_thread.start()
    try:
        yield  # The app runs here
    finally:
        # Shutdown
        # None is the sentinel that tells worker_loop to exit.
        app.state.request_queue.sync_q.put(None)
        worker_thread.join()
        app.state.request_queue.shutdown()
        app.state.kv_cache_pool.finalize()
        app.state.model.destroy_model_instance()
App = FastAPI(lifespan=lifespan)
# App loop: take requests from the queue, do inference, and put unfinished
# requests back into the queue.
def worker_loop(app):
    """Blocking worker thread: batch queued InferTasks (up to MAX_BATCH),
    run one inference round, and re-queue unfinished tasks.

    Exits when the shutdown sentinel (None) is taken from the queue.
    """
    while True:
        try:
            task = app.state.request_queue.sync_q.get(timeout=0.01)
        except queue.Empty:
            continue
        if task is None:
            return
        batch = [task]
        # Opportunistically drain more pending requests into this batch.
        while len(batch) < MAX_BATCH:
            try:
                req = app.state.request_queue.sync_q.get_nowait()
            except queue.Empty:
                break
            if req is None:
                # Bug fix: the original dropped the shutdown sentinel here,
                # so the worker never exited and lifespan's join() hung.
                # Re-queue it so the outer loop sees it after this batch.
                app.state.request_queue.sync_q.put(None)
                break
            batch.append(req)
        output_tokens = app.state.model.batch_infer_one_round(batch)
        for btask, token in zip(batch, output_tokens):
            btask.output(token)
            if btask.finish_reason is None:
                # Not done yet: schedule for the next round.
                app.state.request_queue.sync_q.put(btask)
            else:
                print(f"[INFO] Task {btask.id} finished infer.")
                app.state.kv_cache_pool.release_sync(btask)
def build_task(id_, request_data, request: Request):
    """Build an AsyncInferTask from an OpenAI-style chat request payload,
    applying the model's chat template and tokenizing the prompt."""
    model = request.app.state.model
    prompt = model.tokenizer.apply_chat_template(
        conversation=request_data.get("messages", []),
        add_generation_prompt=True,
        tokenize=False,
    )
    prompt_tokens = model.tokenizer.encode(prompt)
    return AsyncInferTask(
        id_,
        prompt_tokens,
        request_data.get("max_tokens", model.max_context_len()),
        request_data.get("temperature", 1.0),
        request_data.get("top_k", 1),
        request_data.get("top_p", 1.0),
        model.eos_token_id,
    )
async def chat_stream(id_, request_data, request: Request): async def chat_stream(id_, request_data, request: Request):
try: try:
await request_lock.acquire() infer_task = build_task(id_, request_data, request)
await request.app.state.kv_cache_pool.acquire(infer_task)
# Initial empty content
chunk = json.dumps( chunk = json.dumps(
chunk_json(id_, content="", role="assistant"), chunk_json(id_, content="", role="assistant"), ensure_ascii=False
ensure_ascii=False,
) )
yield f"{chunk}\n\n" yield f"{chunk}\n\n"
async for token in model.chat_stream_async(request_data, kv_cache): request.app.state.request_queue.sync_q.put(infer_task)
while True:
if await request.is_disconnected(): if await request.is_disconnected():
print("Client disconnected. Aborting stream.") print("Client disconnected. Aborting stream.")
break break
chunk = json.dumps( if (
chunk_json(id_, content=token), infer_task.finish_reason is not None
ensure_ascii=False, and infer_task.output_queue.async_q.empty()
):
chunk = json.dumps(
chunk_json(id_, finish_reason=infer_task.finish_reason),
ensure_ascii=False,
)
yield f"{chunk}\n\n"
break
token = await infer_task.output_queue.async_q.get()
content = (
request.app.state.model.tokenizer._tokenizer.id_to_token(token)
.replace("▁", " ")
.replace("<0x0A>", "\n")
) )
chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
yield f"{chunk}\n\n" yield f"{chunk}\n\n"
except Exception as e:
print(f"[Error] ID : {id_} Exception: {e}")
finally: finally:
if request_lock.locked(): if infer_task.finish_reason is None:
request_lock.release() infer_task.finish_reason = "cancel"
chunk = json.dumps(
chunk_json(id_, finish_reason="stop"),
ensure_ascii=False,
)
yield f"{chunk}\n\n"
def chat(id_, request_data): async def chat(id_, request_data, request: Request):
output_text = model.chat( try:
request_data, infer_task = build_task(id_, request_data, request)
kv_cache, await request.app.state.kv_cache_pool.acquire(infer_task)
) request.app.state.request_queue.sync_q.put(infer_task)
response = chunk_json( output = []
id_, content=output_text.strip(), role="assistant", finish_reason="stop" while True:
) if (
return JSONResponse(response) infer_task.finish_reason is not None
and infer_task.output_queue.async_q.empty()
):
break
token = await infer_task.output_queue.async_q.get()
content = (
request.app.state.model.tokenizer._tokenizer.id_to_token(token)
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
output.append(content)
output_text = "".join(output).strip()
response = chunk_json(
id_,
content=output_text,
role="assistant",
finish_reason=infer_task.finish_reason or "stop",
)
return response
except Exception as e:
print(f"[Error] ID: {id_} Exception: {e}")
return JSONResponse(content={"error": str(e)}, status_code=500)
finally:
if infer_task.finish_reason is None:
infer_task.finish_reason = "cancel"
@app.post("/chat/completions") @App.post("/chat/completions")
async def chat_completions(request: Request): async def chat_completions(request: Request):
data = await request.json() data = await request.json()
...@@ -134,11 +250,11 @@ async def chat_completions(request: Request): ...@@ -134,11 +250,11 @@ async def chat_completions(request: Request):
chat_stream(id_, data, request), media_type="text/event-stream" chat_stream(id_, data, request), media_type="text/event-stream"
) )
else: else:
return chat(id_, data) return JSONResponse(chat(id_, data))
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(App, host="0.0.0.0", port=8000)
""" """
curl -N -H "Content-Type: application/json" \ curl -N -H "Content-Type: application/json" \
......
...@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int): ...@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_MOORE = 5 DEVICE_TYPE_MOORE = 5
class JiugeMeta(ctypes.Structure): class JiugeMetaCStruct(ctypes.Structure):
_fields_ = [ _fields_ = [
("dt_logits", DataType), ("dt_logits", DataType),
("nlayer", c_size_t), ("nlayer", c_size_t),
...@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure): ...@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct # Define the JiugeWeights struct
class JiugeWeights(ctypes.Structure): class JiugeWeightsCStruct(ctypes.Structure):
_fields_ = [ _fields_ = [
("nlayer", c_size_t), ("nlayer", c_size_t),
("dt_norm", DataType), ("dt_norm", DataType),
...@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure): ...@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
] ]
class JiugeModel(ctypes.Structure): class JiugeModelCSruct(ctypes.Structure):
pass pass
class KVCache(ctypes.Structure): class KVCacheCStruct(ctypes.Structure):
pass pass
...@@ -85,30 +85,31 @@ def __open_library__(): ...@@ -85,30 +85,31 @@ def __open_library__():
os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so"
) )
lib = ctypes.CDLL(lib_path) lib = ctypes.CDLL(lib_path)
lib.createJiugeModel.restype = POINTER(JiugeModel) lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
lib.createJiugeModel.argtypes = [ lib.createJiugeModel.argtypes = [
POINTER(JiugeMeta), # JiugeMeta const * POINTER(JiugeMetaCStruct), # JiugeMeta const *
POINTER(JiugeWeights), # JiugeWeights const * POINTER(JiugeWeightsCStruct), # JiugeWeights const *
DeviceType, # DeviceType DeviceType, # DeviceType
c_int, # int ndev c_int, # int ndev
POINTER(c_int), # int const *dev_ids POINTER(c_int), # int const *dev_ids
] ]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)] lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.restype = POINTER(KVCache) lib.createKVCache.argtypes = [POINTER(JiugeModelCSruct)]
lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)] lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(JiugeModelCSruct), POINTER(KVCacheCStruct)]
lib.inferBatch.restype = None lib.inferBatch.restype = None
lib.inferBatch.argtypes = [ lib.inferBatch.argtypes = [
POINTER(JiugeModel), # struct JiugeModel const * POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok c_uint, # unsigned int ntok
POINTER(c_uint), # unsigned int const *req_lens POINTER(c_uint), # unsigned int const *req_lens
c_uint, # unsigned int nreq c_uint, # unsigned int nreq
POINTER(c_uint), # unsigned int const *req_pos POINTER(c_uint), # unsigned int const *req_pos
POINTER(POINTER(KVCache)), # struct KVCache **kv_caches POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches
POINTER(c_float), # float temperature
POINTER(c_uint), # unsigned int topk
POINTER(c_float), # float topp
POINTER(c_uint), # unsigned int *output POINTER(c_uint), # unsigned int *output
c_float, # float temperature
c_uint, # unsigned int topk
c_float, # float topp
] ]
return lib return lib
......
...@@ -5,13 +5,14 @@ from concurrent.futures import ThreadPoolExecutor, as_completed ...@@ -5,13 +5,14 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
API_URL = "http://localhost:8000/chat/completions" API_URL = "http://localhost:8000/chat/completions"
MODEL = "FM9G-7B" MODEL = "FM9G-7B"
PROMPT = ["给我讲个故事", "山东最高的山是?"] PROMPT = ["山东最高的山是?", "给我讲个故事"]
CONCURRENCY = 10 # 并发用户数量 CONCURRENCY = 10 # 并发用户数量
def single_run(user_id): def single_run(user_id):
payload = { payload = {
"model": MODEL, "model": MODEL,
"messages": [{"role": "user", "content": PROMPT[user_id % len(PROMPT)]}], "messages": [{"role": "user", "content": PROMPT[user_id % len(PROMPT)]}],
"max_tokens": 512,
"stream": True "stream": True
} }
headers = {'Content-Type': 'application/json', 'Accept': 'application/json'} headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
...@@ -86,6 +87,9 @@ def main(): ...@@ -86,6 +87,9 @@ def main():
if r['stream_time'] < best_stream: if r['stream_time'] < best_stream:
best_stream = r['stream_time'] best_stream = r['stream_time']
best = r best = r
# Sort results by user ID
results.sort(key=lambda x: x["user"])
with open("responses.txt", "w", encoding="utf-8") as fw: with open("responses.txt", "w", encoding="utf-8") as fw:
for r in results: for r in results:
......
...@@ -115,8 +115,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -115,8 +115,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches, struct KVCache **kv_caches,
uint32_t *ans, const float *temperature, const uint32_t *topk, const float *topp,
float temperature, uint32_t topk, float topp) { uint32_t *output) {
auto nlayer = meta.nlayer; auto nlayer = meta.nlayer;
auto nkvh = meta.nkvh / ndev; auto nkvh = meta.nkvh / ndev;
auto nh = meta.nh / ndev; auto nh = meta.nh / ndev;
...@@ -457,8 +457,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -457,8 +457,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infiniopRandomSample( RUN_INFINI(infiniopRandomSample(
desc_sample, workspace, workspace_size, desc_sample, workspace, workspace_size,
result_buf->data(req), result_buf->data(req),
prob_buf->data(req * dvoc), random_val, topp, prob_buf->data(req * dvoc),
topk, temperature, stream)); random_val,
topp[req], topk[req], temperature[req],
stream));
// result_buf->debug(); // result_buf->debug();
token_offset += seq_len; token_offset += seq_len;
} }
...@@ -466,7 +468,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -466,7 +468,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(), RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H)); sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
ans[req] = result_cpu[req]; output[req] = result_cpu[req];
} }
} }
...@@ -500,15 +502,15 @@ inferBatch(struct JiugeModel *model, ...@@ -500,15 +502,15 @@ inferBatch(struct JiugeModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches, struct KVCache **kv_caches,
uint32_t *ans, const float *temperature, const uint32_t *topk, const float *topp,
float temperature, uint32_t topk, float topp) { uint32_t *output) {
model->req.tokens = tokens; model->req.tokens = tokens;
model->req.ntok = ntok; model->req.ntok = ntok;
model->req.req_lens = req_lens; model->req.req_lens = req_lens;
model->req.nreq = nreq; model->req.nreq = nreq;
model->req.req_pos = req_pos; model->req.req_pos = req_pos;
model->req.kv_caches = kv_caches; model->req.kv_caches = kv_caches;
model->req.ans = ans; model->req.output = output;
model->req.temperature = temperature; model->req.temperature = temperature;
model->req.topk = topk; model->req.topk = topk;
model->req.topp = topp; model->req.topp = topp;
...@@ -547,7 +549,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso ...@@ -547,7 +549,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
break; break;
} }
inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.ans, req.temperature, req.topk, req.topp); inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk, req.topp, req.output);
state.proceed = false; state.proceed = false;
lock.unlock(); lock.unlock();
......
...@@ -45,10 +45,10 @@ struct InferRequest { ...@@ -45,10 +45,10 @@ struct InferRequest {
uint32_t nreq; uint32_t nreq;
const uint32_t *req_pos; const uint32_t *req_pos;
struct KVCache **kv_caches; struct KVCache **kv_caches;
uint32_t *ans; const float *temperature;
float temperature; const uint32_t *topk;
uint32_t topk; const float *topp;
float topp; uint32_t *output;
}; };
struct JiugeModel { struct JiugeModel {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment