Unverified commit 0892d37d, authored by wang jiahao, committed by GitHub
Browse files

Merge pull request #1172 from kvcache-ai/move_create_sched

Move KV cache creation to balance_serve
parents e44c45e7 38e84190
...@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient ...@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
import torch.multiprocessing as mp import torch.multiprocessing as mp
from multiprocessing.synchronize import Event
from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler from ktransformers.server.utils.multi_timer import Profiler
import zmq import zmq
...@@ -41,8 +42,10 @@ import threading ...@@ -41,8 +42,10 @@ import threading
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
import os import os
import pickle
import subprocess
import tempfile
import atexit
ktransformer_rules_dir = ( ktransformer_rules_dir = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/") os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
...@@ -99,7 +102,7 @@ class Engine: ...@@ -99,7 +102,7 @@ class Engine:
sampler: Sampler sampler: Sampler
query_manager: QueryManager query_manager: QueryManager
cache: KDeepSeekV3Cache cache: KDeepSeekV3Cache
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None): def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
self.args = args self.args = args
# 子进程和父进程无法共享 config 变量 # 子进程和父进程无法共享 config 变量
...@@ -115,14 +118,6 @@ class Engine: ...@@ -115,14 +118,6 @@ class Engine:
self.gen_queue = generated_token_queue self.gen_queue = generated_token_queue
print(f"Getting inference context from sched_client.")
inference_context = self.sched_client.get_inference_context_raw()
print(f"Got inference context, sending it to subscribers.")
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
self.cache.load(inference_context)
print(f"kv_cache loaded successfully.")
self.block_num = inference_context.k_cache[0].size(1)
with torch.device("meta"): with torch.device("meta"):
if config.architectures[0] == "DeepseekV3ForCausalLM": if config.architectures[0] == "DeepseekV3ForCausalLM":
self.model = KDeepseekV3ForCausalLM(config, self.cache) self.model = KDeepseekV3ForCausalLM(config, self.cache)
...@@ -165,6 +160,17 @@ class Engine: ...@@ -165,6 +160,17 @@ class Engine:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.model.eval() self.model.eval()
kvcache_event.set()
# load kvcache
print(f"Getting inference context from sched_client.")
inference_context = self.sched_client.get_inference_context_raw()
print(f"Got inference context, sending it to subscribers.")
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
self.cache.load(inference_context)
print(f"kv_cache loaded successfully.")
self.block_num = inference_context.k_cache[0].size(1)
#@TODO add config #@TODO add config
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
...@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext): ...@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
return local_messages return local_messages
def run_engine(args, token_queue, broadcast_endpoint, event): def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
engine = Engine(args, token_queue, broadcast_endpoint) engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
if args.use_cuda_graph: if args.use_cuda_graph:
engine.model_runner.warmup() engine.model_runner.warmup()
...@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase): ...@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
self.streamer = TextStreamer(self.tokenizer) self.streamer = TextStreamer(self.tokenizer)
start_event = ctx.Event() start_event = ctx.Event()
kvcache_event = ctx.Event()
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event)) p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
p.start() p.start()
processes.append(p) processes.append(p)
kvcache_event.wait()
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
start_event.wait() start_event.wait()
def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]: def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
......
...@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles ...@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging import uvicorn.logging
import uvicorn import uvicorn
import sys import sys
import atexit
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser from ktransformers.server.args import ArgumentParser
...@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger from ktransformers.server.config.log import logger
import subprocess
import tempfile
def mount_app_routes(mount_app: FastAPI): def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil() sql_util = SQLUtil()
...@@ -108,27 +106,6 @@ def main(): ...@@ -108,27 +106,6 @@ def main():
arg_parser = ArgumentParser(cfg) arg_parser = ArgumentParser(cfg)
args = arg_parser.parse_args() args = arg_parser.parse_args()
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg) create_interface(config=cfg, default_args=cfg)
app = create_app() app = create_app()
custom_openapi(app) custom_openapi(app)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment