OpenDAS / ktransformers

Unverified commit 0892d37d, authored Apr 18, 2025 by wang jiahao, committed by GitHub Apr 18, 2025

Merge pull request #1172 from kvcache-ai/move_create_sched

Move KV cache creation to balance_serve

Parents: e44c45e7, 38e84190
Showing 2 changed files with 45 additions and 38 deletions (+45 -38):

  ktransformers/server/backend/interfaces/balance_serve.py  (+44 -14)
  ktransformers/server/main.py  (+1 -24)
ktransformers/server/backend/interfaces/balance_serve.py
@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
 from ktransformers.server.balance_serve.settings import sched_ext
 from torch.multiprocessing import Queue
 import torch.multiprocessing as mp
+from multiprocessing.synchronize import Event
 from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.utils.multi_timer import Profiler
 import zmq
@@ -41,8 +42,10 @@ import threading
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request
 import os
 import pickle
 import subprocess
 import tempfile
 import atexit
 ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
@@ -99,7 +102,7 @@ class Engine:
     sampler: Sampler
     query_manager: QueryManager
     cache: KDeepSeekV3Cache
-    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None):
+    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
         # 子进程和父进程无法共享 config 变量 (child and parent processes cannot share the config variable)
@@ -115,14 +118,6 @@ class Engine:
         self.gen_queue = generated_token_queue

-        print(f"Getting inference context from sched_client.")
-        inference_context = self.sched_client.get_inference_context_raw()
-        print(f"Got inference context, sending it to subscribers.")
-        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
-        self.cache.load(inference_context)
-        print(f"kv_cache loaded successfully.")
-        self.block_num = inference_context.k_cache[0].size(1)

         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
@@ -165,6 +160,17 @@ class Engine:
         self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.model.eval()
+        kvcache_event.set()
+
+        # load kvcache
+        print(f"Getting inference context from sched_client.")
+        inference_context = self.sched_client.get_inference_context_raw()
+        print(f"Got inference context, sending it to subscribers.")
+        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
+        self.cache.load(inference_context)
+        print(f"kv_cache loaded successfully.")
+        self.block_num = inference_context.k_cache[0].size(1)
+
         #@TODO add config
         self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
         return local_messages

-def run_engine(args, token_queue, broadcast_endpoint, event):
-    engine = Engine(args, token_queue, broadcast_endpoint)
+def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
+    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
     if args.use_cuda_graph:
         engine.model_runner.warmup()
@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
         self.streamer = TextStreamer(self.tokenizer)

         start_event = ctx.Event()
+        kvcache_event = ctx.Event()
-        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
+        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
         p.start()
         processes.append(p)
+        kvcache_event.wait()
+
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+
+        current_file = __file__
+        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+        sched_process = subprocess.Popen(["python3", target_file, "--config", temp_file_path], stdout=log, stderr=log)
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+
         start_event.wait()

     def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
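Taken together, the balance_serve.py changes introduce a two-stage handshake between the interface process and the engine subprocess: the engine sets kvcache_event once the model weights are loaded (so the parent knows it is safe to start the scheduler RPC process), then fetches the inference context and loads the KV cache, while the parent blocks on start_event.wait() until the engine is fully up. The snippet below is a minimal, self-contained sketch of that Event handshake pattern, not the repository's engine code; the dummy run_engine body is an assumption for illustration, and it uses the standard multiprocessing module (torch.multiprocessing, used in the diff, exposes the same API).

# Minimal sketch of the two-event start-up handshake (illustrative only).
import multiprocessing as mp

def run_engine(start_event, kvcache_event):
    # ... load model weights here (placeholder) ...
    kvcache_event.set()    # model loaded: parent may now start the scheduler RPC process
    # ... fetch the inference context and load the KV cache here (placeholder) ...
    start_event.set()      # engine fully initialized and ready to serve

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    start_event = ctx.Event()
    kvcache_event = ctx.Event()
    p = ctx.Process(target=run_engine, args=(start_event, kvcache_event))
    p.start()
    kvcache_event.wait()   # safe point to launch sched_rpc
    start_event.wait()     # engine is up; continue serving requests
    p.join()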
ktransformers/server/main.py
@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
-import atexit
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
 import subprocess
 import tempfile

 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -108,27 +106,6 @@ def main():
     arg_parser = ArgumentParser(cfg)
     args = arg_parser.parse_args()

-    if args.backend_type == "balance_serve":
-        import pickle
-
-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()
-
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            pickle.dump(args, temp_file)
-            temp_file_path = temp_file.name
-
-        current_file = __file__
-        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
-        target_file = os.path.normpath(target_file)
-        log_path = os.path.join(args.log_dir, "rpc.log")
-        log = open(log_path, "a")
-        sched_process = subprocess.Popen(["python3", target_file, "--config", temp_file_path], stdout=log, stderr=log)
-        print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
-
     create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)
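The block removed from main.py above now lives in BalanceServeInterface (see the balance_serve.py diff), where it runs only after kvcache_event.wait(), i.e. after the engine has loaded the model. The pattern it uses is: pickle the parsed args to a temporary file, start sched_rpc.py as a child process with --config pointing at that file and its output appended to a log, and register an atexit hook so the child is terminated when the server exits. Below is a generic sketch of that pattern under the assumption that args is picklable and provides log_dir; the launch_sched_rpc helper is illustrative and not part of the repository.

# Illustrative launcher for a scheduler-style child process (not repository code).
import atexit
import os
import pickle
import subprocess
import tempfile

def launch_sched_rpc(args, script_path):
    # Serialize the config where the child process can read it back via --config.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        pickle.dump(args, temp_file)
        temp_file_path = temp_file.name

    # Append the child's stdout/stderr to a log file, as the original code does.
    log = open(os.path.join(args.log_dir, "rpc.log"), "a")
    sched_process = subprocess.Popen(
        ["python3", script_path, "--config", temp_file_path],
        stdout=log,
        stderr=log,
    )

    def cleanup():
        # Terminate the scheduler only if it is still running at interpreter exit.
        if sched_process.poll() is None:
            sched_process.terminate()

    atexit.register(cleanup)
    return sched_process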