Commit 2b9ce5a6 authored by Pan Zezhong's avatar Pan Zezhong
Browse files

refactor task to avoid model dependency on async libs

parent cfc8b598
import janus
class InferTask:
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
self.id = id
......@@ -11,21 +8,24 @@ class InferTask:
self.topk = topk
self.topp = topp
self.end_tokens = end_tokens
self.output_queue = janus.Queue()
self._kv_cache_pool_item = None
self._kv_cache = None
self.pos = 0
print(f"[INFO] Create InferTask {self.id}")
def bind_kvcache(self, kv_cache_pool_item, pos):
self._kv_cache_pool_item = kv_cache_pool_item
def bind_kvcache(self, kv_cache, pos=0):
self._kv_cache = kv_cache
self.pos = pos
self.tokens = self.tokens[pos:]
def release_kvcache(self):
cache = self._kv_cache
self._kv_cache = None
return cache
def kvcache(self):
return self._kv_cache_pool_item.kvcache
return self._kv_cache
def output(self, out_token):
self._kv_cache_pool_item.update_tokens(self.tokens, self.pos)
def next(self, out_token):
self._kv_cache.update_tokens(self.tokens, self.pos)
self.pos += len(self.tokens)
if out_token == None or out_token in self.end_tokens:
......@@ -35,4 +35,25 @@ class InferTask:
else:
self.tokens = [out_token]
self.output_queue.sync_q.put(out_token)
class KVCache:
def __init__(self, model):
self._kvcache = model.create_kv_cache()
self.tokens = [0 for _ in range(model.max_context_len())]
def data(self):
return self._kvcache
def drop(self, model):
model.drop_kv_cache(self._kvcache)
def update_tokens(self, tokens, pos):
end = pos + len(tokens)
max_len = len(self.tokens)
# If overflow, truncate tokens to fit
if end > max_len:
tokens = tokens[: max_len - pos]
end = max_len
self.tokens[pos:end] = tokens
from typing import List
from libinfinicore_infer import (
JiugeMeta,
JiugeWeights,
KVCache,
JiugeMetaCStruct,
JiugeWeightsCStruct,
KVCacheCStruct,
DataType,
DeviceType,
create_jiuge_model,
......@@ -11,7 +11,7 @@ from libinfinicore_infer import (
drop_kv_cache,
infer_batch,
)
from infer_task import InferTask
from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
......@@ -25,6 +25,7 @@ import transformers
torch.set_default_device("cpu")
class LlamaWeightsNaming:
def input_embd(self):
return "model.embed_tokens.weight"
......@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
)
class JiugeMetaFromLlama(JiugeMeta):
class JiugeMetaFromLlama(JiugeMetaCStruct):
def __init__(self, config, dtype=torch.float16):
if dtype == torch.float16:
dt_ = DataType.INFINI_DTYPE_F16
......@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
self.torch_dtype_logits = dtype
class JiugeWeightsImpl(JiugeWeights):
class JiugeWeightsImpl(JiugeWeightsCStruct):
def __init__(
self,
meta,
......@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight:
self.output_embd_tensor = self.output_embd_tensor.transpose(0, 1).contiguous()
self.output_embd_tensor = self.output_embd_tensor.transpose(
0, 1
).contiguous()
self.output_embd = self.output_embd_tensor.data_ptr()
self.attn_norm_tensors = [
......@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
]
if not transpose_weight:
for i in range(nlayer):
self.qkv_tensor[i] = self.qkv_tensor[i].reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d).transpose(1, 2).contiguous()
self.qkv_tensor[i] = (
self.qkv_tensor[i]
.reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d)
.transpose(1, 2)
.contiguous()
)
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)
......@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_qkv_b = None
self.attn_o_tensor = [
(
state_dict[naming.attn_o(i)]
.to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1)
.contiguous()
if transpose_weight
else state_dict[naming.attn_o(i)].transpose(0, 1).to(torch_dt_mat).contiguous()
else state_dict[naming.attn_o(i)]
.transpose(0, 1)
.to(torch_dt_mat)
.contiguous()
)
for i in range(nlayer)
]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
......@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
]
if not transpose_weight:
for i in range(nlayer):
self.gate_up_tensors[i] = self.gate_up_tensors[i].reshape(ndev, 2 * di // ndev, d).transpose(1, 2).contiguous()
self.gate_up_tensors[i] = (
self.gate_up_tensors[i]
.reshape(ndev, 2 * di // ndev, d)
.transpose(1, 2)
.contiguous()
)
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
self.ffn_down_tensor = [
(
state_dict[naming.down(i)]
.to(torch_dt_mat)
.reshape([d, ndev, di // ndev])
.transpose(0, 1)
.contiguous()
if transpose_weight
else state_dict[naming.down(i)].transpose(0, 1).to(torch_dt_mat).contiguous()
else state_dict[naming.down(i)]
.transpose(0, 1)
.to(torch_dt_mat)
.contiguous()
)
for i in range(nlayer)
]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
......@@ -296,7 +319,7 @@ class JiugeBatchedTask:
token_lists = [t.tokens for t in tasks]
self.req_lens_list = [len(toks) for toks in token_lists]
self.req_pos_list = [t.pos for t in tasks]
self.kv_cache_ptrs = [t.kvcache() for t in tasks]
self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
self.temperaturas_list = [t.temperature for t in tasks]
self.topks_list = [t.topk for t in tasks]
self.topps_list = [t.topp for t in tasks]
......@@ -309,7 +332,7 @@ class JiugeBatchedTask:
self.tokens = (c_uint * self.ntok)(*flat_tokens)
self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
self.kv_caches = (POINTER(KVCache) * self.nreq)(*self.kv_cache_ptrs)
self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
self.topks = (c_uint * self.nreq)(*self.topks_list)
self.topps = (c_float * self.nreq)(*self.topps_list)
......@@ -346,26 +369,46 @@ class JiugeForCauslLM:
config = json.load(f)
self.config = config
eos_token_id = self.config["eos_token_id"]
self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id
transpose_weight = device != DeviceType.DEVICE_TYPE_ASCEND # y = xW is faster than y=xW^T on Ascend
self.eos_token_id = (
[eos_token_id] if type(eos_token_id) == int else eos_token_id
)
transpose_weight = (
device != DeviceType.DEVICE_TYPE_ASCEND
) # y = xW is faster than y=xW^T on Ascend
if "llama" == config["model_type"]:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half()
model = (
transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
.cpu()
.half()
)
self.meta = JiugeMetaFromLlama(config)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev, transpose_weight=transpose_weight
self.meta,
LlamaWeightsNaming(),
model.state_dict(),
ndev=ndev,
transpose_weight=transpose_weight,
)
elif "fm9g" == config["model_type"]:
if any(file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()):
if any(
file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
):
state_dict = load_all_safetensors_from_dir(model_dir_path)
else:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight
self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
......@@ -374,12 +417,18 @@ class JiugeForCauslLM:
raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight
self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
......@@ -391,7 +440,11 @@ class JiugeForCauslLM:
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight
self.meta,
LlamaWeightsNaming(),
state_dict,
ndev=ndev,
transpose_weight=transpose_weight,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path
......@@ -435,7 +488,6 @@ class JiugeForCauslLM:
return list(output)
def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
kv_cache = create_kv_cache(self.model_instance)
input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True,
......@@ -443,39 +495,26 @@ class JiugeForCauslLM:
)
print(input_content, end="", flush=True)
tokens = self.tokenizer.encode(input_content)
ntok = len(tokens)
nreq = 1
output_content = ""
tokens = (c_uint * ntok)(*tokens)
req_lens = (c_uint * nreq)(*[ntok])
req_pos = (c_uint * nreq)(*[0])
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
ans = (c_uint * nreq)()
temperature = (c_float * nreq)(*[temperature_])
topk = (c_uint * nreq)(*[topk_])
topp = (c_float * nreq)(*[topp_])
infer_task = InferTask(
0,
tokens,
self.max_context_len(),
temperature_,
topk_,
topp_,
self.eos_token_id,
)
infer_task.bind_kvcache(KVCache(self))
steps = 0
total_time = 0
output_content = ""
for step_i in range(max_steps):
start_time = time.time()
infer_batch(
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
temperature,
topk,
topp,
ans,
)
steps += 1
output_tokens = list(ans)
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
......@@ -485,10 +524,7 @@ class JiugeForCauslLM:
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
break
req_pos[0] = req_pos[0] + ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
infer_task.next(output_tokens[0])
if step_i > 0:
total_time += end_time - start_time
......@@ -496,8 +532,8 @@ class JiugeForCauslLM:
print("\n")
avg_time = total_time * 1000 / (steps - 1)
print(f"Time per step: {avg_time:.3f}ms")
for kv_cache in kv_caches:
drop_kv_cache(self.model_instance, kv_cache)
infer_task._kv_cache.drop(self)
return output_content, avg_time
def destroy_model_instance(self):
......
from infer_task import KVCache
import asyncio
from typing import List
class KVCachePoolItem:
def __init__(self, model):
self.kvcache = model.create_kv_cache()
self.tokens = [0 for _ in range(model.max_context_len())]
def drop(self, model):
model.drop_kv_cache(self.kvcache)
def update_tokens(self, tokens, pos):
end = pos + len(tokens)
max_len = len(self.tokens)
# If overflow, truncate tokens to fit
if end > max_len:
tokens = tokens[: max_len - pos]
end = max_len
self.tokens[pos:end] = tokens
import threading
......@@ -29,7 +9,7 @@ class KVCachePool:
def __init__(self, model, max_caches: int = 32):
self.max_caches = max_caches
self.model = model
self._available: List[KVCachePoolItem] = []
self._available: List[KVCache] = []
self.num_caches = len(self._available)
self._lock = threading.Lock()
self._not_empty = threading.Condition(self._lock)
......@@ -45,8 +25,10 @@ class KVCachePool:
if len(self._available) == 0:
if self.num_caches < self.max_caches:
self.num_caches += 1
print(f"[INFO] Task {infer_task.id} created new KVCachePoolItem")
return infer_task.bind_kvcache(KVCachePoolItem(self.model), 0)
print(
f"[INFO] Task {infer_task.id} created new KVCachePoolItem"
)
return infer_task.bind_kvcache(KVCache(self.model), 0)
else:
self._not_empty.wait()
else:
......@@ -62,8 +44,7 @@ class KVCachePool:
def release_sync(self, infer_task):
with self._not_empty:
print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
self._available.append(infer_task._kv_cache_pool_item)
infer_task._kv_cache_pool_item = None
self._available.append(infer_task.release_kvcache())
self._not_empty.notify()
async def acquire(self, infer_task):
......
......@@ -70,6 +70,17 @@ print(
f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
)
# A wrapper for InferTask that supports async output queue
class AsyncInferTask(InferTask):
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens)
self.output_queue = janus.Queue()
print(f"[INFO] Create InferTask {self.id}")
def output(self, out_token):
self.next(out_token)
self.output_queue.sync_q.put(out_token)
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
......@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
tokenize=False,
)
tokens = request.app.state.model.tokenizer.encode(input_content)
return InferTask(
return AsyncInferTask(
id_,
tokens,
request_data.get("max_tokens", request.app.state.model.max_context_len()),
......
......@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_MOORE = 5
class JiugeMeta(ctypes.Structure):
class JiugeMetaCStruct(ctypes.Structure):
_fields_ = [
("dt_logits", DataType),
("nlayer", c_size_t),
......@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct
class JiugeWeights(ctypes.Structure):
class JiugeWeightsCStruct(ctypes.Structure):
_fields_ = [
("nlayer", c_size_t),
("dt_norm", DataType),
......@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
]
class JiugeModel(ctypes.Structure):
class JiugeModelCSruct(ctypes.Structure):
pass
class KVCache(ctypes.Structure):
class KVCacheCStruct(ctypes.Structure):
pass
......@@ -85,27 +85,27 @@ def __open_library__():
os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so"
)
lib = ctypes.CDLL(lib_path)
lib.createJiugeModel.restype = POINTER(JiugeModel)
lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
lib.createJiugeModel.argtypes = [
POINTER(JiugeMeta), # JiugeMeta const *
POINTER(JiugeWeights), # JiugeWeights const *
POINTER(JiugeMetaCStruct), # JiugeMeta const *
POINTER(JiugeWeightsCStruct), # JiugeWeights const *
DeviceType, # DeviceType
c_int, # int ndev
POINTER(c_int), # int const *dev_ids
]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)]
lib.createKVCache.argtypes = [POINTER(JiugeModel)]
lib.createKVCache.restype = POINTER(KVCache)
lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(JiugeModelCSruct), POINTER(KVCacheCStruct)]
lib.inferBatch.restype = None
lib.inferBatch.argtypes = [
POINTER(JiugeModel), # struct JiugeModel const *
POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
POINTER(c_uint), # unsigned int const *req_lens
c_uint, # unsigned int nreq
POINTER(c_uint), # unsigned int const *req_pos
POINTER(POINTER(KVCache)), # struct KVCache **kv_caches
POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches
POINTER(c_float), # float temperature
POINTER(c_uint), # unsigned int topk
POINTER(c_float), # float topp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment