"git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "6ae4832227c8d6a1f20f9da76685f9fd74138468"
Commit 2b9ce5a6 authored by Pan Zezhong's avatar Pan Zezhong
Browse files

refactor task to avoid model dependency on async libs

parent cfc8b598
import janus
class InferTask: class InferTask:
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
self.id = id self.id = id
...@@ -11,21 +8,24 @@ class InferTask: ...@@ -11,21 +8,24 @@ class InferTask:
self.topk = topk self.topk = topk
self.topp = topp self.topp = topp
self.end_tokens = end_tokens self.end_tokens = end_tokens
self.output_queue = janus.Queue() self._kv_cache = None
self._kv_cache_pool_item = None
self.pos = 0 self.pos = 0
print(f"[INFO] Create InferTask {self.id}")
def bind_kvcache(self, kv_cache_pool_item, pos): def bind_kvcache(self, kv_cache, pos=0):
self._kv_cache_pool_item = kv_cache_pool_item self._kv_cache = kv_cache
self.pos = pos self.pos = pos
self.tokens = self.tokens[pos:] self.tokens = self.tokens[pos:]
def release_kvcache(self):
cache = self._kv_cache
self._kv_cache = None
return cache
def kvcache(self): def kvcache(self):
return self._kv_cache_pool_item.kvcache return self._kv_cache
def output(self, out_token): def next(self, out_token):
self._kv_cache_pool_item.update_tokens(self.tokens, self.pos) self._kv_cache.update_tokens(self.tokens, self.pos)
self.pos += len(self.tokens) self.pos += len(self.tokens)
if out_token == None or out_token in self.end_tokens: if out_token == None or out_token in self.end_tokens:
...@@ -35,4 +35,25 @@ class InferTask: ...@@ -35,4 +35,25 @@ class InferTask:
else: else:
self.tokens = [out_token] self.tokens = [out_token]
self.output_queue.sync_q.put(out_token)
class KVCache:
def __init__(self, model):
self._kvcache = model.create_kv_cache()
self.tokens = [0 for _ in range(model.max_context_len())]
def data(self):
return self._kvcache
def drop(self, model):
model.drop_kv_cache(self._kvcache)
def update_tokens(self, tokens, pos):
end = pos + len(tokens)
max_len = len(self.tokens)
# If overflow, truncate tokens to fit
if end > max_len:
tokens = tokens[: max_len - pos]
end = max_len
self.tokens[pos:end] = tokens
from typing import List from typing import List
from libinfinicore_infer import ( from libinfinicore_infer import (
JiugeMeta, JiugeMetaCStruct,
JiugeWeights, JiugeWeightsCStruct,
KVCache, KVCacheCStruct,
DataType, DataType,
DeviceType, DeviceType,
create_jiuge_model, create_jiuge_model,
...@@ -11,7 +11,7 @@ from libinfinicore_infer import ( ...@@ -11,7 +11,7 @@ from libinfinicore_infer import (
drop_kv_cache, drop_kv_cache,
infer_batch, infer_batch,
) )
from infer_task import InferTask from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os import os
...@@ -25,6 +25,7 @@ import transformers ...@@ -25,6 +25,7 @@ import transformers
torch.set_default_device("cpu") torch.set_default_device("cpu")
class LlamaWeightsNaming: class LlamaWeightsNaming:
def input_embd(self): def input_embd(self):
return "model.embed_tokens.weight" return "model.embed_tokens.weight"
...@@ -78,7 +79,7 @@ class LlamaWeightsNaming: ...@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
) )
class JiugeMetaFromLlama(JiugeMeta): class JiugeMetaFromLlama(JiugeMetaCStruct):
def __init__(self, config, dtype=torch.float16): def __init__(self, config, dtype=torch.float16):
if dtype == torch.float16: if dtype == torch.float16:
dt_ = DataType.INFINI_DTYPE_F16 dt_ = DataType.INFINI_DTYPE_F16
...@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta): ...@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
self.torch_dtype_logits = dtype self.torch_dtype_logits = dtype
class JiugeWeightsImpl(JiugeWeights): class JiugeWeightsImpl(JiugeWeightsCStruct):
def __init__( def __init__(
self, self,
meta, meta,
...@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
self.output_norm = self.output_norm_tensor.data_ptr() self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight: if not transpose_weight:
self.output_embd_tensor = self.output_embd_tensor.transpose(0, 1).contiguous() self.output_embd_tensor = self.output_embd_tensor.transpose(
0, 1
).contiguous()
self.output_embd = self.output_embd_tensor.data_ptr() self.output_embd = self.output_embd_tensor.data_ptr()
self.attn_norm_tensors = [ self.attn_norm_tensors = [
...@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
] ]
if not transpose_weight: if not transpose_weight:
for i in range(nlayer): for i in range(nlayer):
self.qkv_tensor[i] = self.qkv_tensor[i].reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d).transpose(1, 2).contiguous() self.qkv_tensor[i] = (
self.qkv_tensor[i]
.reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d)
.transpose(1, 2)
.contiguous()
)
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)
...@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_qkv_b = None self.attn_qkv_b = None
self.attn_o_tensor = [ self.attn_o_tensor = [
state_dict[naming.attn_o(i)] (
state_dict[naming.attn_o(i)]
.to(torch_dt_mat) .to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh]) .reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1) .transpose(0, 1)
.contiguous() .contiguous()
if transpose_weight if transpose_weight
else state_dict[naming.attn_o(i)].transpose(0, 1).to(torch_dt_mat).contiguous() else state_dict[naming.attn_o(i)]
.transpose(0, 1)
.to(torch_dt_mat)
.contiguous()
)
for i in range(nlayer) for i in range(nlayer)
] ]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights): ...@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
] ]
if not transpose_weight: if not transpose_weight:
for i in range(nlayer): for i in range(nlayer):
self.gate_up_tensors[i] = self.gate_up_tensors[i].reshape(ndev, 2 * di // ndev, d).transpose(1, 2).contiguous() self.gate_up_tensors[i] = (
self.gate_up_tensors[i]
.reshape(ndev, 2 * di // ndev, d)
.transpose(1, 2)
.contiguous()
)
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
self.ffn_down_tensor = [ self.ffn_down_tensor = [
state_dict[naming.down(i)] (
.to(torch_dt_mat) state_dict[naming.down(i)]
.reshape([d, ndev, di // ndev]) .to(torch_dt_mat)
.transpose(0, 1) .reshape([d, ndev, di // ndev])
.contiguous() .transpose(0, 1)
if transpose_weight .contiguous()
else state_dict[naming.down(i)].transpose(0, 1).to(torch_dt_mat).contiguous() if transpose_weight
else state_dict[naming.down(i)]
.transpose(0, 1)
.to(torch_dt_mat)
.contiguous()
)
for i in range(nlayer) for i in range(nlayer)
] ]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -296,7 +319,7 @@ class JiugeBatchedTask: ...@@ -296,7 +319,7 @@ class JiugeBatchedTask:
token_lists = [t.tokens for t in tasks] token_lists = [t.tokens for t in tasks]
self.req_lens_list = [len(toks) for toks in token_lists] self.req_lens_list = [len(toks) for toks in token_lists]
self.req_pos_list = [t.pos for t in tasks] self.req_pos_list = [t.pos for t in tasks]
self.kv_cache_ptrs = [t.kvcache() for t in tasks] self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
self.temperaturas_list = [t.temperature for t in tasks] self.temperaturas_list = [t.temperature for t in tasks]
self.topks_list = [t.topk for t in tasks] self.topks_list = [t.topk for t in tasks]
self.topps_list = [t.topp for t in tasks] self.topps_list = [t.topp for t in tasks]
...@@ -309,7 +332,7 @@ class JiugeBatchedTask: ...@@ -309,7 +332,7 @@ class JiugeBatchedTask:
self.tokens = (c_uint * self.ntok)(*flat_tokens) self.tokens = (c_uint * self.ntok)(*flat_tokens)
self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
self.kv_caches = (POINTER(KVCache) * self.nreq)(*self.kv_cache_ptrs) self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
self.topks = (c_uint * self.nreq)(*self.topks_list) self.topks = (c_uint * self.nreq)(*self.topks_list)
self.topps = (c_float * self.nreq)(*self.topps_list) self.topps = (c_float * self.nreq)(*self.topps_list)
...@@ -338,7 +361,7 @@ class JiugeForCauslLM: ...@@ -338,7 +361,7 @@ class JiugeForCauslLM:
for name_ in data_.keys(): for name_ in data_.keys():
tensors_[name_] = data_.get_tensor(name_) tensors_[name_] = data_.get_tensor(name_)
return tensors_ return tensors_
print("Loading model weights to host...") print("Loading model weights to host...")
load_start_time = time.time() load_start_time = time.time()
...@@ -346,26 +369,46 @@ class JiugeForCauslLM: ...@@ -346,26 +369,46 @@ class JiugeForCauslLM:
config = json.load(f) config = json.load(f)
self.config = config self.config = config
eos_token_id = self.config["eos_token_id"] eos_token_id = self.config["eos_token_id"]
self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id self.eos_token_id = (
transpose_weight = device != DeviceType.DEVICE_TYPE_ASCEND # y = xW is faster than y=xW^T on Ascend [eos_token_id] if type(eos_token_id) == int else eos_token_id
)
transpose_weight = (
device != DeviceType.DEVICE_TYPE_ASCEND
) # y = xW is faster than y=xW^T on Ascend
if "llama" == config["model_type"]: if "llama" == config["model_type"]:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half() model = (
transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
.cpu()
.half()
)
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
model.state_dict(),
ndev=ndev,
transpose_weight=transpose_weight,
) )
elif "fm9g" == config["model_type"]: elif "fm9g" == config["model_type"]:
if any(file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()): if any(
file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
):
state_dict = load_all_safetensors_from_dir(model_dir_path) state_dict = load_all_safetensors_from_dir(model_dir_path)
else: else:
state_dict = torch.load( state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu" os.path.join(model_dir_path, "pytorch_model.bin"),
) weights_only=True,
map_location="cpu",
)
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
) )
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True model_dir_path, trust_remote_code=True
...@@ -374,12 +417,18 @@ class JiugeForCauslLM: ...@@ -374,12 +417,18 @@ class JiugeForCauslLM:
raise ValueError("Unsupported weight naming") raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]: elif "fm9g7b" == config["model_type"]:
state_dict = torch.load( state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu" os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
) )
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
) )
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True model_dir_path, trust_remote_code=True
...@@ -391,7 +440,11 @@ class JiugeForCauslLM: ...@@ -391,7 +440,11 @@ class JiugeForCauslLM:
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
) )
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path model_dir_path
...@@ -401,7 +454,7 @@ class JiugeForCauslLM: ...@@ -401,7 +454,7 @@ class JiugeForCauslLM:
load_end_time = time.time() load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s") print(f"Time used: {load_end_time - load_start_time:.3f}s")
print(f"Creating model on {ndev} devices...") print(f"Creating model on {ndev} devices...")
load_start_time = time.time() load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
...@@ -414,10 +467,10 @@ class JiugeForCauslLM: ...@@ -414,10 +467,10 @@ class JiugeForCauslLM:
) )
load_end_time = time.time() load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s") print(f"Time used: {load_end_time - load_start_time:.3f}s")
def max_context_len(self): def max_context_len(self):
return self.meta.dctx return self.meta.dctx
def create_kv_cache(self): def create_kv_cache(self):
return create_kv_cache(self.model_instance) return create_kv_cache(self.model_instance)
...@@ -435,7 +488,6 @@ class JiugeForCauslLM: ...@@ -435,7 +488,6 @@ class JiugeForCauslLM:
return list(output) return list(output)
def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
kv_cache = create_kv_cache(self.model_instance)
input_content = self.tokenizer.apply_chat_template( input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}], conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True, add_generation_prompt=True,
...@@ -443,39 +495,26 @@ class JiugeForCauslLM: ...@@ -443,39 +495,26 @@ class JiugeForCauslLM:
) )
print(input_content, end="", flush=True) print(input_content, end="", flush=True)
tokens = self.tokenizer.encode(input_content) tokens = self.tokenizer.encode(input_content)
ntok = len(tokens) infer_task = InferTask(
nreq = 1 0,
output_content = "" tokens,
tokens = (c_uint * ntok)(*tokens) self.max_context_len(),
req_lens = (c_uint * nreq)(*[ntok]) temperature_,
req_pos = (c_uint * nreq)(*[0]) topk_,
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache]) topp_,
ans = (c_uint * nreq)() self.eos_token_id,
temperature = (c_float * nreq)(*[temperature_]) )
topk = (c_uint * nreq)(*[topk_]) infer_task.bind_kvcache(KVCache(self))
topp = (c_float * nreq)(*[topp_])
steps = 0 steps = 0
total_time = 0 total_time = 0
output_content = ""
for step_i in range(max_steps): for step_i in range(max_steps):
start_time = time.time() start_time = time.time()
infer_batch( output_tokens = self.batch_infer_one_round([infer_task])
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
temperature,
topk,
topp,
ans,
)
steps += 1
output_tokens = list(ans)
end_time = time.time() end_time = time.time()
steps += 1
output_str = ( output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0]) self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ") .replace("▁", " ")
...@@ -485,10 +524,7 @@ class JiugeForCauslLM: ...@@ -485,10 +524,7 @@ class JiugeForCauslLM:
print(output_str, end="", flush=True) print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id: if output_tokens[0] in self.eos_token_id:
break break
req_pos[0] = req_pos[0] + ntok infer_task.next(output_tokens[0])
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
if step_i > 0: if step_i > 0:
total_time += end_time - start_time total_time += end_time - start_time
...@@ -496,8 +532,8 @@ class JiugeForCauslLM: ...@@ -496,8 +532,8 @@ class JiugeForCauslLM:
print("\n") print("\n")
avg_time = total_time * 1000 / (steps - 1) avg_time = total_time * 1000 / (steps - 1)
print(f"Time per step: {avg_time:.3f}ms") print(f"Time per step: {avg_time:.3f}ms")
for kv_cache in kv_caches:
drop_kv_cache(self.model_instance, kv_cache) infer_task._kv_cache.drop(self)
return output_content, avg_time return output_content, avg_time
def destroy_model_instance(self): def destroy_model_instance(self):
......
from infer_task import KVCache
import asyncio import asyncio
from typing import List from typing import List
class KVCachePoolItem:
def __init__(self, model):
self.kvcache = model.create_kv_cache()
self.tokens = [0 for _ in range(model.max_context_len())]
def drop(self, model):
model.drop_kv_cache(self.kvcache)
def update_tokens(self, tokens, pos):
end = pos + len(tokens)
max_len = len(self.tokens)
# If overflow, truncate tokens to fit
if end > max_len:
tokens = tokens[: max_len - pos]
end = max_len
self.tokens[pos:end] = tokens
import threading import threading
...@@ -29,7 +9,7 @@ class KVCachePool: ...@@ -29,7 +9,7 @@ class KVCachePool:
def __init__(self, model, max_caches: int = 32): def __init__(self, model, max_caches: int = 32):
self.max_caches = max_caches self.max_caches = max_caches
self.model = model self.model = model
self._available: List[KVCachePoolItem] = [] self._available: List[KVCache] = []
self.num_caches = len(self._available) self.num_caches = len(self._available)
self._lock = threading.Lock() self._lock = threading.Lock()
self._not_empty = threading.Condition(self._lock) self._not_empty = threading.Condition(self._lock)
...@@ -45,8 +25,10 @@ class KVCachePool: ...@@ -45,8 +25,10 @@ class KVCachePool:
if len(self._available) == 0: if len(self._available) == 0:
if self.num_caches < self.max_caches: if self.num_caches < self.max_caches:
self.num_caches += 1 self.num_caches += 1
print(f"[INFO] Task {infer_task.id} created new KVCachePoolItem") print(
return infer_task.bind_kvcache(KVCachePoolItem(self.model), 0) f"[INFO] Task {infer_task.id} created new KVCachePoolItem"
)
return infer_task.bind_kvcache(KVCache(self.model), 0)
else: else:
self._not_empty.wait() self._not_empty.wait()
else: else:
...@@ -62,8 +44,7 @@ class KVCachePool: ...@@ -62,8 +44,7 @@ class KVCachePool:
def release_sync(self, infer_task): def release_sync(self, infer_task):
with self._not_empty: with self._not_empty:
print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool") print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
self._available.append(infer_task._kv_cache_pool_item) self._available.append(infer_task.release_kvcache())
infer_task._kv_cache_pool_item = None
self._not_empty.notify() self._not_empty.notify()
async def acquire(self, infer_task): async def acquire(self, infer_task):
......
...@@ -70,6 +70,17 @@ print( ...@@ -70,6 +70,17 @@ print(
f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
) )
# A wrapper for InferTask that supports async output queue
class AsyncInferTask(InferTask):
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens)
self.output_queue = janus.Queue()
print(f"[INFO] Create InferTask {self.id}")
def output(self, out_token):
self.next(out_token)
self.output_queue.sync_q.put(out_token)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
...@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request): ...@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
tokenize=False, tokenize=False,
) )
tokens = request.app.state.model.tokenizer.encode(input_content) tokens = request.app.state.model.tokenizer.encode(input_content)
return InferTask( return AsyncInferTask(
id_, id_,
tokens, tokens,
request_data.get("max_tokens", request.app.state.model.max_context_len()), request_data.get("max_tokens", request.app.state.model.max_context_len()),
......
...@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int): ...@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_MOORE = 5 DEVICE_TYPE_MOORE = 5
class JiugeMeta(ctypes.Structure): class JiugeMetaCStruct(ctypes.Structure):
_fields_ = [ _fields_ = [
("dt_logits", DataType), ("dt_logits", DataType),
("nlayer", c_size_t), ("nlayer", c_size_t),
...@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure): ...@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct # Define the JiugeWeights struct
class JiugeWeights(ctypes.Structure): class JiugeWeightsCStruct(ctypes.Structure):
_fields_ = [ _fields_ = [
("nlayer", c_size_t), ("nlayer", c_size_t),
("dt_norm", DataType), ("dt_norm", DataType),
...@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure): ...@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
] ]
class JiugeModel(ctypes.Structure): class JiugeModelCSruct(ctypes.Structure):
pass pass
class KVCache(ctypes.Structure): class KVCacheCStruct(ctypes.Structure):
pass pass
...@@ -85,27 +85,27 @@ def __open_library__(): ...@@ -85,27 +85,27 @@ def __open_library__():
os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so"
) )
lib = ctypes.CDLL(lib_path) lib = ctypes.CDLL(lib_path)
lib.createJiugeModel.restype = POINTER(JiugeModel) lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
lib.createJiugeModel.argtypes = [ lib.createJiugeModel.argtypes = [
POINTER(JiugeMeta), # JiugeMeta const * POINTER(JiugeMetaCStruct), # JiugeMeta const *
POINTER(JiugeWeights), # JiugeWeights const * POINTER(JiugeWeightsCStruct), # JiugeWeights const *
DeviceType, # DeviceType DeviceType, # DeviceType
c_int, # int ndev c_int, # int ndev
POINTER(c_int), # int const *dev_ids POINTER(c_int), # int const *dev_ids
] ]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)] lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.argtypes = [POINTER(JiugeModel)] lib.createKVCache.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.restype = POINTER(KVCache) lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)] lib.dropKVCache.argtypes = [POINTER(JiugeModelCSruct), POINTER(KVCacheCStruct)]
lib.inferBatch.restype = None lib.inferBatch.restype = None
lib.inferBatch.argtypes = [ lib.inferBatch.argtypes = [
POINTER(JiugeModel), # struct JiugeModel const * POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok c_uint, # unsigned int ntok
POINTER(c_uint), # unsigned int const *req_lens POINTER(c_uint), # unsigned int const *req_lens
c_uint, # unsigned int nreq c_uint, # unsigned int nreq
POINTER(c_uint), # unsigned int const *req_pos POINTER(c_uint), # unsigned int const *req_pos
POINTER(POINTER(KVCache)), # struct KVCache **kv_caches POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches
POINTER(c_float), # float temperature POINTER(c_float), # float temperature
POINTER(c_uint), # unsigned int topk POINTER(c_uint), # unsigned int topk
POINTER(c_float), # float topp POINTER(c_float), # float topp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment