Commit 21f83e91 authored by Pan Zezhong

add kv cache pool

parent a73433ab
infer_task.py (new file):

import asyncio


class InferTask:
    def __init__(self, id, tokenizer, request):
        self.id_ = id
        self.finished_reason = None
        messages = request.get("messages", [])
        if len(messages) == 0:
            self.finished_reason = "invalid request"
            self.tokens = []
        else:
            input_content = tokenizer.apply_chat_template(
                conversation=messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            self.tokens = tokenizer.encode(input_content)
        self.request = request
        self.output_queue = asyncio.Queue()
        self._kv_cache_pool_item = None
        self.pos = 0

    def bind_kvcache(self, kv_cache_pool_item, pos):
        # The first `pos` prompt tokens are already present in the cache,
        # so only the remaining suffix still needs to be prefilled.
        self._kv_cache_pool_item = kv_cache_pool_item
        self.pos = pos
        self.tokens = self.tokens[pos:]

    def kvcache(self):
        return self._kv_cache_pool_item.kvcache
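
A minimal usage sketch (the stub tokenizer and fake pool item below are illustrative, not part of the commit): after binding at position 1, the task keeps only the prompt suffix that still needs prefilling.

from types import SimpleNamespace

class StubTokenizer:
    # Illustrative stand-in for the real tokenizer.
    def apply_chat_template(self, conversation, add_generation_prompt, tokenize):
        return " ".join(m["content"] for m in conversation)

    def encode(self, text):
        return [ord(c) for c in text]  # fake token ids

task = InferTask(0, StubTokenizer(), {"messages": [{"role": "user", "content": "hi"}]})
item = SimpleNamespace(kvcache="fake-cache")  # stands in for a KVCachePoolItem
task.bind_kvcache(item, 1)                    # first token already cached
assert task.pos == 1 and task.tokens == [105] and task.kvcache() == "fake-cache"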
jiuge.py (modified):

@@ -372,6 +372,9 @@ class JiugeForCauslLM:
         )
         load_end_time = time.time()
         print(f"Time used: {load_end_time - load_start_time:.3f}s")

+    def max_context_len(self):
+        return self.meta.dctx
+
     def create_kv_cache(self):
         return create_kv_cache(self.model_instance)
...
kvcache_pool.py (new file):

import asyncio
from typing import List


class KVCachePoolItem:
    def __init__(self, model):
        self.kvcache = model.create_kv_cache()
        # Token ids currently held in this cache, used for prefix matching.
        self.tokens = [0 for _ in range(model.max_context_len())]

    def drop(self, model):
        model.drop_kv_cache(self.kvcache)


class KVCachePool:
    def __init__(self, model, max_caches: int = 32):
        self.max_caches = max_caches
        self.model = model
        self._available: List[KVCachePoolItem] = [KVCachePoolItem(self.model)]
        self.num_caches = 1
        self._lock = asyncio.Lock()
        self._not_empty = asyncio.Condition(self._lock)
        self._shutdown = False

    async def acquire(self, infer_task):
        async with self._not_empty:
            while True:
                if self._shutdown:
                    raise RuntimeError(
                        "KVCachePool is shutting down; cannot acquire new cache."
                    )
                if len(self._available) == 0:
                    if self.num_caches < self.max_caches:
                        # Grow the pool instead of waiting.
                        self.num_caches += 1
                        return infer_task.bind_kvcache(KVCachePoolItem(self.model), 0)
                    else:
                        await self._not_empty.wait()
                else:
                    # Reuse the cache sharing the longest prefix with this prompt.
                    max_match, max_match_index = self.find_most_matching_cache(
                        infer_task.tokens
                    )
                    kvcache = self._available.pop(max_match_index)
                    return infer_task.bind_kvcache(kvcache, max_match)

    async def release(self, infer_task):
        async with self._not_empty:
            self._available.append(infer_task._kv_cache_pool_item)
            self._not_empty.notify()

    def find_most_matching_cache(self, tokens: List[int]):
        max_match = 0
        max_match_index = 0

        def first_different_index(a_, b_):
            for i_, (x_, y_) in enumerate(zip(a_, b_)):
                if x_ != y_:
                    return i_
            return min(len(a_), len(b_))

        for i, kvcache in enumerate(self._available):
            common_elements = first_different_index(tokens, kvcache.tokens)
            if common_elements > max_match:
                max_match = common_elements
                max_match_index = i
        # The match is capped at len(tokens) - 1 so the task always keeps
        # at least one token to feed the next inference step.
        return (min(max_match, len(tokens) - 1), max_match_index)

    async def finalize(self):
        async with self._not_empty:
            self._shutdown = True
            # Wait for every outstanding cache to be released.
            while len(self._available) < self.num_caches:
                await self._not_empty.wait()
            # All caches are now available; free them.
            for kvcache in self._available:
                if kvcache is not None:
                    kvcache.drop(self.model)
            self._available.clear()
            self.max_caches = 0
            self.num_caches = 0
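
A runnable sketch of the acquire/release flow under a stub model (all stub names are illustrative assumptions). Note that nothing in this file writes KVCachePoolItem.tokens; the pool only reads it for matching, so whoever runs inference is expected to record the cached tokens, which the sketch does by hand:

import asyncio

class StubModel:
    # Illustrative stand-in for the model's pool-facing interface.
    def __init__(self):
        self._count = 0

    def max_context_len(self):
        return 8

    def create_kv_cache(self):
        self._count += 1
        return f"cache-{self._count}"

    def drop_kv_cache(self, kvcache):
        print("dropped", kvcache)

class StubTask:
    # Minimal object exposing the interface acquire()/release() rely on.
    def __init__(self, tokens):
        self.tokens = tokens
        self._kv_cache_pool_item = None
        self.pos = 0

    def bind_kvcache(self, item, pos):
        self._kv_cache_pool_item = item
        self.pos = pos
        self.tokens = self.tokens[pos:]

async def main():
    pool = KVCachePool(StubModel(), max_caches=2)

    t1 = StubTask([1, 2, 3, 4])
    await pool.acquire(t1)
    # In the real server, inference fills the cache; here we record the
    # tokens by hand so the next request can prefix-match against them.
    t1._kv_cache_pool_item.tokens[:4] = [1, 2, 3, 4]
    await pool.release(t1)

    t2 = StubTask([1, 2, 3, 9])  # shares a 3-token prefix with t1
    await pool.acquire(t2)
    print(t2.pos, t2.tokens)     # -> 3 [9]: only the suffix needs prefill
    await pool.release(t2)

    await pool.finalize()        # frees "cache-1"

asyncio.run(main())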
API server script (modified):

+import asyncio
 from jiuge import JiugeForCauslLM
 from libinfinicore_infer import DeviceType
+from infer_task import InferTask
+from kvcache_pool import KVCachePool
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse, JSONResponse
@@ -37,26 +40,26 @@ else:
     sys.exit(1)

 ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
-model = JiugeForCauslLM(model_path, device_type, ndev)
-kv_cache = model.create_kv_cache()
+MODEL = JiugeForCauslLM(model_path, device_type, ndev)
+App = FastAPI()

-def signal_handler(sig, frame):
-    print(f"Received signal {sig}, cleaning up...")
-    model.drop_kv_cache(kv_cache)
-    model.destroy_model_instance()
+@App.on_event("startup")
+async def setup():
+    App.state.kv_cache_pool = KVCachePool(MODEL, 1)
+
+async def handle_shutdown():
+    await App.state.kv_cache_pool.finalize()
+    MODEL.destroy_model_instance()
     sys.exit(0)

+def signal_handler(sig, frame):
+    print(f"Received signal {sig}, cleaning up...")
+    asyncio.create_task(handle_shutdown())
+
 signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
 signal.signal(signal.SIGTERM, signal_handler)  # Handle docker stop / system shutdown

-app = FastAPI()
-
-# TO REMOVE: Global lock to ensure only one request is handled at a time
-# Remove this after multiple requests handling is implemented
-request_lock = anyio.Lock()

 def chunk_json(id_, content=None, role=None, finish_reason=None):
     delta = {}
@@ -83,14 +86,18 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):

 async def chat_stream(id_, request_data, request: Request):
     try:
-        await request_lock.acquire()
+        infer_task = InferTask(id_, MODEL.tokenizer, request_data)
+        await App.state.kv_cache_pool.acquire(infer_task)
         chunk = json.dumps(
             chunk_json(id_, content="", role="assistant"),
             ensure_ascii=False,
         )
         yield f"{chunk}\n\n"
-        async for token in model.chat_stream_async(request_data, kv_cache):
+        async for token in MODEL.chat_stream_async(
+            infer_task.request,
+            infer_task.kvcache(),
+        ):
             if await request.is_disconnected():
                 print("Client disconnected. Aborting stream.")
                 break
@@ -100,8 +107,7 @@ async def chat_stream(id_, request_data, request: Request):
         )
         yield f"{chunk}\n\n"
     finally:
-        if request_lock.locked():
-            request_lock.release()
+        await App.state.kv_cache_pool.release(infer_task)
         chunk = json.dumps(
             chunk_json(id_, finish_reason="stop"),
             ensure_ascii=False,
@@ -109,18 +115,22 @@ async def chat_stream(id_, request_data, request: Request):
         yield f"{chunk}\n\n"

-def chat(id_, request_data):
-    output_text = model.chat(
-        request_data,
-        kv_cache,
+async def chat(id_, request_data):
+    infer_task = InferTask(id_, MODEL.tokenizer, request_data)
+    await App.state.kv_cache_pool.acquire(infer_task)
+    output_text = MODEL.chat(
+        infer_task.request,
+        infer_task.kvcache(),
     )
     response = chunk_json(
         id_, content=output_text.strip(), role="assistant", finish_reason="stop"
     )
+    await App.state.kv_cache_pool.release(infer_task)
     return JSONResponse(response)

-@app.post("/chat/completions")
+@App.post("/chat/completions")
 async def chat_completions(request: Request):
     data = await request.json()
@@ -138,7 +148,7 @@ async def chat_completions(request: Request):

 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(App, host="0.0.0.0", port=8000)

 """
 curl -N -H "Content-Type: application/json" \
...
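
For reference, a complete request might look like the following; the host, port, route, and messages field come from the code above, while the stream flag and exact payload shape are assumptions, not taken from the commit:

curl -N -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": true}' \
  http://localhost:8000/chat/completions   # payload shape assumed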