Commit 21f83e91 authored by Pan Zezhong's avatar Pan Zezhong
Browse files

add kv cache pool

parent a73433ab
import asyncio
class InferTask:
    """One inference request: the tokenized prompt plus its KV-cache binding.

    Tokenizes the chat messages from ``request`` up front; ``bind_kvcache``
    later attaches a pool item and trims tokens that the cache already holds.
    """

    def __init__(self, id, tokenizer, request):
        self.id_ = id
        self.finished_reason = None
        messages = request.get("messages", [])
        if not messages:
            # Nothing to run; mark the task as failed immediately.
            self.finished_reason = "invalid request"
            self.tokens = []
        else:
            input_content = tokenizer.apply_chat_template(
                conversation=messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            self.tokens = tokenizer.encode(input_content)
        self.request = request
        self.output_queue = asyncio.Queue()
        self._kv_cache_pool_item = None  # set by bind_kvcache()
        self.pos = 0  # number of prompt tokens already present in the cache

    def bind_kvcache(self, kv_cache_pool_item, pos):
        """Attach a KV-cache pool item; the first ``pos`` tokens are already
        cached, so drop them from the pending token list."""
        self._kv_cache_pool_item = kv_cache_pool_item
        self.pos = pos
        self.tokens = self.tokens[pos:]

    def kvcache(self):
        """Return the raw KV cache of the bound pool item."""
        return self._kv_cache_pool_item.kvcache
......@@ -372,6 +372,9 @@ class JiugeForCauslLM:
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
def max_context_len(self):
    """Return the maximum context length supported by this model.

    Reads the ``dctx`` field from the loaded model metadata.
    """
    meta = self.meta
    return meta.dctx
def create_kv_cache(self):
    """Allocate a new KV cache for this model instance.

    Delegates to the module-level ``create_kv_cache`` binding (presumably
    from the libinfinicore_infer C bindings — confirm) with the native
    model handle.
    """
    return create_kv_cache(self.model_instance)
......
import asyncio
from typing import List
class KVCachePoolItem:
    """A pooled KV cache handle plus the token ids it currently stores."""

    def __init__(self, model):
        # Backing cache handle, owned by this item until drop().
        self.kvcache = model.create_kv_cache()
        # Tokens cached so far, pre-sized to the model's maximum context
        # and zero-filled ([0] * n instead of a comprehension).
        self.tokens = [0] * model.max_context_len()

    def drop(self, model):
        """Release the backing cache through the model that allocated it."""
        model.drop_kv_cache(self.kvcache)
class KVCachePool:
    """An asyncio pool of reusable KV caches with longest-prefix reuse.

    Caches are created lazily up to ``max_caches``; when all are in use,
    acquirers wait on a condition variable until one is released.
    """

    def __init__(self, model, max_caches: int = 32):
        self.max_caches = max_caches
        self.model = model
        # Start with one cache so the first request never blocks on creation.
        self._available: List[KVCachePoolItem] = [KVCachePoolItem(self.model)]
        self.num_caches = 1  # total caches created (available + in use)
        self._lock = asyncio.Lock()
        self._not_empty = asyncio.Condition(self._lock)
        self._shutdown = False

    async def acquire(self, infer_task):
        """Bind a KV cache to ``infer_task``, reusing the best prefix match.

        Blocks until a cache is available (or a new one can be created).
        Raises RuntimeError once the pool is shutting down.
        """
        async with self._not_empty:
            while True:
                if self._shutdown:
                    raise RuntimeError(
                        "KVCachePool is shutting down; cannot acquire new cache."
                    )
                if len(self._available) == 0:
                    if self.num_caches < self.max_caches:
                        # Grow the pool instead of waiting.
                        self.num_caches += 1
                        return infer_task.bind_kvcache(
                            KVCachePoolItem(self.model), 0
                        )
                    await self._not_empty.wait()
                else:
                    max_match, max_match_index = self.find_most_matching_cache(
                        infer_task.tokens
                    )
                    kvcache = self._available.pop(max_match_index)
                    return infer_task.bind_kvcache(kvcache, max_match)

    async def release(self, infer_task):
        """Return the task's cache to the pool and wake one waiter."""
        async with self._not_empty:
            self._available.append(infer_task._kv_cache_pool_item)
            self._not_empty.notify()

    def find_most_matching_cache(self, tokens: List[int]):
        """Return ``(prefix_len, index)`` of the available cache sharing the
        longest common token prefix with ``tokens``.

        ``prefix_len`` is clamped to ``len(tokens) - 1`` — but never below
        0 — so the model is always left at least one token to process.
        """
        max_match = 0
        max_match_index = 0

        def first_different_index(a_, b_):
            for i_, (x_, y_) in enumerate(zip(a_, b_)):
                if x_ != y_:
                    return i_
            return min(len(a_), len(b_))

        for i, kvcache in enumerate(self._available):
            common_elements = first_different_index(tokens, kvcache.tokens)
            if common_elements > max_match:
                max_match = common_elements
                max_match_index = i
        # Max match must stay below the input length so at least one token
        # remains; the max(..., 0) keeps empty inputs from yielding -1.
        return (min(max_match, max(len(tokens) - 1, 0)), max_match_index)

    async def finalize(self):
        """Shut down: wait for every cache to be returned, then drop them all."""
        async with self._not_empty:
            self._shutdown = True
            # Wait until all outstanding caches are released back to the pool.
            while len(self._available) < self.num_caches:
                await self._not_empty.wait()
            # All caches are now available
            for kvcache in self._available:
                if kvcache is not None:
                    kvcache.drop(self.model)
            self._available.clear()
            self.max_caches = 0
            self.num_caches = 0
import asyncio
from jiuge import JiugeForCauslLM
from libinfinicore_infer import DeviceType
from infer_task import InferTask
from kvcache_pool import KVCachePool
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
......@@ -37,26 +40,26 @@ else:
sys.exit(1)
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1

# Single global model instance and FastAPI app. Per-request KV caches now
# come from the KVCachePool created at startup, so no global cache is
# allocated here (the pre-rename ``model``/``kv_cache`` globals are gone).
MODEL = JiugeForCauslLM(model_path, device_type, ndev)
App = FastAPI()
def signal_handler(sig, frame):
print(f"Received signal {sig}, cleaning up...")
model.drop_kv_cache(kv_cache)
model.destroy_model_instance()
@App.on_event("startup")
async def setup():
    # Create the shared KV-cache pool once the server starts.
    # NOTE(review): pool size is hard-coded to 1, which serializes all
    # requests — confirm whether this should be configurable.
    # NOTE(review): ``on_event`` is deprecated in newer FastAPI versions in
    # favor of lifespan handlers — consider migrating.
    App.state.kv_cache_pool = KVCachePool(MODEL, 1)
async def handle_shutdown():
    """Drain the KV-cache pool, free the native model, then exit."""
    # finalize() blocks until every in-flight request has released its cache.
    await App.state.kv_cache_pool.finalize()
    MODEL.destroy_model_instance()
    sys.exit(0)
def signal_handler(sig, frame):
    """Signal hook: schedule the async shutdown sequence."""
    print(f"Received signal {sig}, cleaning up...")
    # NOTE(review): asyncio.create_task requires a running event loop in the
    # current thread; if the signal arrives before/after the loop runs this
    # raises RuntimeError. loop.add_signal_handler would be more robust —
    # TODO confirm against how uvicorn drives the loop.
    asyncio.create_task(handle_shutdown())
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle docker stop / system shutdown
app = FastAPI()
# TO REMOVE: Global lock to ensure only one request is handled at a time
# Remove this after multiple requests handling is implemented
request_lock = anyio.Lock()
def chunk_json(id_, content=None, role=None, finish_reason=None):
delta = {}
......@@ -83,14 +86,18 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
async def chat_stream(id_, request_data, request: Request):
try:
await request_lock.acquire()
infer_task = InferTask(id_, MODEL.tokenizer, request_data)
await App.state.kv_cache_pool.acquire(infer_task)
chunk = json.dumps(
chunk_json(id_, content="", role="assistant"),
ensure_ascii=False,
)
yield f"{chunk}\n\n"
async for token in model.chat_stream_async(request_data, kv_cache):
async for token in MODEL.chat_stream_async(
infer_task.request,
infer_task.kvcache(),
):
if await request.is_disconnected():
print("Client disconnected. Aborting stream.")
break
......@@ -100,8 +107,7 @@ async def chat_stream(id_, request_data, request: Request):
)
yield f"{chunk}\n\n"
finally:
if request_lock.locked():
request_lock.release()
await App.state.kv_cache_pool.release(infer_task)
chunk = json.dumps(
chunk_json(id_, finish_reason="stop"),
ensure_ascii=False,
......@@ -109,18 +115,22 @@ async def chat_stream(id_, request_data, request: Request):
yield f"{chunk}\n\n"
async def chat(id_, request_data):
    """Handle a non-streaming chat completion.

    Builds an InferTask from the request payload, borrows a KV cache from
    the pool, runs the model, and returns the full reply as one JSON chunk.
    """
    infer_task = InferTask(id_, MODEL.tokenizer, request_data)
    await App.state.kv_cache_pool.acquire(infer_task)
    try:
        output_text = MODEL.chat(
            infer_task.request,
            infer_task.kvcache(),
        )
    finally:
        # Always return the cache to the pool, even if inference raises,
        # so pool items are never leaked.
        await App.state.kv_cache_pool.release(infer_task)
    response = chunk_json(
        id_, content=output_text.strip(), role="assistant", finish_reason="stop"
    )
    return JSONResponse(response)
@app.post("/chat/completions")
@App.post("/chat/completions")
async def chat_completions(request: Request):
data = await request.json()
......@@ -138,7 +148,7 @@ async def chat_completions(request: Request):
if __name__ == "__main__":
    # Serve the renamed FastAPI instance (``App``); the pre-rename ``app``
    # global no longer exists in this module.
    uvicorn.run(App, host="0.0.0.0", port=8000)
"""
curl -N -H "Content-Type: application/json" \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment