OpenDAS / ktransformers / Commits / 38e84190

Commit 38e84190, authored Apr 18, 2025 by qiyuxinlin
Parent: 8770b6d5

Move KV cache creation to balance_serve
Showing 2 changed files with 45 additions and 38 deletions:

ktransformers/server/backend/interfaces/balance_serve.py   +44 -14
ktransformers/server/main.py                               +1  -24

ktransformers/server/backend/interfaces/balance_serve.py
@@ -30,6 +30,7 @@ from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
 from ktransformers.server.balance_serve.settings import sched_ext
 from torch.multiprocessing import Queue
 import torch.multiprocessing as mp
+from multiprocessing.synchronize import Event
 from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.utils.multi_timer import Profiler
 import zmq
@@ -41,8 +42,10 @@ import threading
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request
 import os
+import pickle
+import subprocess
+import tempfile
+import atexit

 ktransformer_rules_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
@@ -99,7 +102,7 @@ class Engine:
     sampler: Sampler
     query_manager: QueryManager
     cache: KDeepSeekV3Cache
-    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None):
+    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
         # the child process and the parent process cannot share the config variable
@@ -115,14 +118,6 @@ class Engine:
         self.gen_queue = generated_token_queue

-        print(f"Getting inference context from sched_client.")
-        inference_context = self.sched_client.get_inference_context_raw()
-        print(f"Got inference context, sending it to subscribers.")
-        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
-        self.cache.load(inference_context)
-        print(f"kv_cache loaded successfully.")
-        self.block_num = inference_context.k_cache[0].size(1)
-
         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
@@ -165,6 +160,17 @@ class Engine:
         self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.model.eval()
+        kvcache_event.set()
+
+        # load kvcache
+        print(f"Getting inference context from sched_client.")
+        inference_context = self.sched_client.get_inference_context_raw()
+        print(f"Got inference context, sending it to subscribers.")
+        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
+        self.cache.load(inference_context)
+        print(f"kv_cache loaded successfully.")
+        self.block_num = inference_context.k_cache[0].size(1)
+
         #@TODO add config
         self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
@@ -240,8 +246,8 @@ class BalanceServeThreadContext(ThreadContext):
         return local_messages

-def run_engine(args, token_queue, broadcast_endpoint, event):
-    engine = Engine(args, token_queue, broadcast_endpoint)
+def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
+    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
     if args.use_cuda_graph:
         engine.model_runner.warmup()
@@ -278,10 +284,34 @@ class BalanceServeInterface(BackendInterfaceBase):
         self.streamer = TextStreamer(self.tokenizer)

         start_event = ctx.Event()
+        kvcache_event = ctx.Event()

-        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
+        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
         p.start()
         processes.append(p)
+        kvcache_event.wait()
+
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+        current_file = __file__
+
+        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+        sched_process = subprocess.Popen(["python3", target_file, "--config", temp_file_path], stdout=log, stderr=log)
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+
         start_event.wait()

     def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
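The coordination these hunks introduce: the engine child process sets kvcache_event once the model is built, the parent waits on that event before it launches the sched_rpc scheduler, and the engine then blocks on the scheduler RPC to fetch the inference context and load the KV cache. Below is a minimal, self-contained sketch of that handshake; it uses plain multiprocessing and sleeps in place of the real torch.multiprocessing engine and scheduler RPC, so run_engine_stub and the timings are illustrative only, not ktransformers code.

# Minimal sketch of the start_event / kvcache_event handshake (illustrative only:
# plain multiprocessing and sleeps stand in for the real engine and sched_rpc).
import multiprocessing as mp
import time


def run_engine_stub(start_event, kvcache_event):
    time.sleep(0.5)                 # stand-in: build the model and load weights
    print("engine: model ready")
    kvcache_event.set()             # parent may now start the scheduler that owns the KV cache

    time.sleep(0.5)                 # stand-in: sched_client.get_inference_context_raw()
    print("engine: kv_cache loaded")
    start_event.set()               # engine is fully ready to serve


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    start_event = ctx.Event()
    kvcache_event = ctx.Event()

    p = ctx.Process(target=run_engine_stub, args=(start_event, kvcache_event))
    p.start()

    kvcache_event.wait()            # parent: model is up ...
    print("parent: launching sched_rpc here (pickle args, subprocess.Popen)")

    start_event.wait()              # parent: KV cache loaded, warm-up done
    print("parent: engine ready")
    p.join()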
ktransformers/server/main.py
@@ -5,7 +5,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
-import atexit
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
@@ -17,8 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
-import subprocess
-import tempfile

 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -108,27 +106,6 @@ def main():
     arg_parser = ArgumentParser(cfg)
     args = arg_parser.parse_args()
-    if args.backend_type == "balance_serve":
-        import pickle
-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()
-
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            pickle.dump(args, temp_file)
-            temp_file_path = temp_file.name
-        current_file = __file__
-
-        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
-        target_file = os.path.normpath(target_file)
-
-        log_path = os.path.join(args.log_dir, "rpc.log")
-        log = open(log_path, "a")
-        sched_process = subprocess.Popen(["python3", target_file, "--config", temp_file_path], stdout=log, stderr=log)
-        print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
     create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)
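The block removed from main() is exactly the launcher that now lives in BalanceServeInterface: pickle the parsed args into a NamedTemporaryFile, start sched_rpc.py as a separate python3 process pointed at that file via --config, redirect its output to rpc.log, and register an atexit hook that terminates the child if it is still alive. A standalone sketch of that hand-off pattern follows; child_script.py and the Args class are placeholders for illustration, not part of ktransformers.

# Sketch of the "pickle args to a temp file, pass the path via --config" pattern
# used to launch sched_rpc. child_script.py and Args are hypothetical placeholders.
import atexit
import os
import pickle
import subprocess
import tempfile


class Args:
    log_dir = "."                   # illustrative stand-in for the parsed server args


args = Args()

# Dump the config object; the child process unpickles it from this path.
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
    pickle.dump(args, temp_file)
    temp_file_path = temp_file.name

log = open(os.path.join(args.log_dir, "rpc.log"), "a")
child = subprocess.Popen(
    ["python3", "child_script.py", "--config", temp_file_path],
    stdout=log, stderr=log,
)
print("child started with PID:", child.pid)


def cleanup():
    # Same idea as the cleanup() in the diff: stop the child if it is
    # still running when this process exits.
    if child.poll() is None:
        child.terminate()


atexit.register(cleanup)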