change / sglang · Commits · 81561f8e

Flush Cache API (#103)

Unverified commit 81561f8e, authored Jan 26, 2024 by Liangsheng Yin, committed by GitHub Jan 25, 2024.
Parent: 3a581e99
Showing 6 changed files with 48 additions and 64 deletions:
- python/sglang/flush_cache.py (+0, -60)
- python/sglang/srt/managers/io_struct.py (+5, -0)
- python/sglang/srt/managers/router/model_rpc.py (+23, -1)
- python/sglang/srt/managers/router/radix_cache.py (+6, -3)
- python/sglang/srt/managers/tokenizer_manager.py (+5, -0)
- python/sglang/srt/server.py (+9, -0)
python/sglang/flush_cache.py (deleted, 100644 → 0)
"""Flush cache in the backend by sending random requests."""
import
argparse
import
random
import
string
import
time
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
select_sglang_backend
,
)
import
sglang
as
sgl
@
sgl
.
function
def
flush_radix_cache
(
s
,
prompt
):
s
+=
prompt
+
sgl
.
gen
(
"flush"
,
max_tokens
=
1
,
stop
=
"END"
)
def
main
(
args
,
max_total_tokens
,
context_length
,
print_flag
):
backend
=
select_sglang_backend
(
args
)
flush_length
=
int
(
context_length
*
0.8
)
batch_size
=
int
(
max_total_tokens
/
flush_length
)
prompt_length
=
flush_length
*
2
prompts
=
[
" "
.
join
(
random
.
choices
(
string
.
ascii_letters
,
k
=
int
(
prompt_length
)))
for
_
in
range
(
batch_size
)
]
arguments
=
[{
"prompt"
:
prompts
[
i
]}
for
i
in
range
(
batch_size
)]
start_time
=
time
.
time
()
flush_radix_cache
.
run_batch
(
arguments
,
temperature
=
0
,
backend
=
backend
,
num_threads
=
1
)
end_time
=
time
.
time
()
if
print_flag
:
print
(
f
"Flush length:
{
flush_length
}
\n
"
,
f
"Prompt length:
{
prompt_length
}
\n
"
,
f
"Total Prompt letters:
{
batch_size
*
prompt_length
}
\n
"
,
f
"Flush radix cache latency:
{
end_time
-
start_time
:.
3
f
}
"
,
sep
=
""
,
)
# to prevent the backend still running
time
.
sleep
(
1
)
def
run_flush
(
args
,
max_total_tokens
=
20000
,
context_length
=
1024
,
print_flag
=
False
):
main
(
args
,
max_total_tokens
,
context_length
,
print_flag
=
print_flag
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--max-total-tokens"
,
type
=
int
,
default
=
20000
)
parser
.
add_argument
(
"--context-length"
,
type
=
int
,
default
=
1024
)
args
=
add_common_sglang_args_and_parse
(
parser
)
random
.
seed
(
0
)
main
(
args
,
args
.
max_total_tokens
,
args
.
context_length
,
print_flag
=
True
)
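Before this commit, callers cleared the backend's radix cache indirectly by running the script above, which floods the token pool with random prompts until every cached prefix is evicted. A rough usage sketch, not taken from the diff: it assumes `args` already carries whatever backend-selection fields `add_common_sglang_args_and_parse` / `select_sglang_backend` expect, which are not shown in this commit.

    # Sketch only: drive the (now removed) helper from test code.
    # Defaults mirror the argparse defaults above: 20000 total tokens, 1024 context length.
    from sglang.flush_cache import run_flush

    run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=True)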
python/sglang/srt/managers/io_struct.py

@@ -87,3 +87,8 @@ class BatchStrOut:
     output_str: List[str]
     meta_info: List[Dict]
     finished: List[bool]
+
+
+@dataclass
+class FlushCacheReq:
+    pass
python/sglang/srt/managers/router/model_rpc.py

@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
 from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
+from sglang.srt.managers.io_struct import (
+    BatchTokenIDOut,
+    TokenizedGenerateReqInput,
+    FlushCacheReq,
+)
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache

@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                "Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)

@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
+            elif isinstance(recv_req, FlushCacheReq):
+                self.flush_cache()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
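The flush is deliberately conservative: it only resets the radix tree and memory pools when both the waiting queue and the running batch are empty; otherwise it emits a warning and does nothing. A rough sketch, not part of the commit, of how a FlushCacheReq reaches this code: the router's receive loop normally collects such objects sent by the tokenizer manager and hands them to exposed_step, and `model_rpc` below stands in for an already-initialized ModelRpcServer instance (a hypothetical name for illustration).

    from sglang.srt.managers.io_struct import FlushCacheReq

    recv_reqs = [FlushCacheReq()]      # what TokenizerManager.flush_cache() sends over ZeroMQ
    model_rpc.exposed_step(recv_reqs)  # the dispatch loop above routes it to self.flush_cache()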
python/sglang/srt/managers/router/radix_cache.py

@@ -30,14 +30,17 @@ def match(key, seq):
 class RadixCache:
     def __init__(self, disable=False):
+        self.reset()
+        self.disable = disable
+
+    ##### Public API #####
+
+    def reset(self):
         self.root_node = TreeNode()
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
-        self.disable = disable
-
-    ##### Public API #####
 
     def match_prefix(self, key):
         if self.disable:
             return [], self.root_node
python/sglang/srt/managers/tokenizer_manager.py

@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     GenerateReqInput,
     TokenizedGenerateReqInput,
+    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams

@@ -228,6 +229,10 @@ class TokenizerManager:
             yield output_list
 
+    async def flush_cache(self):
+        flush_cache_req = FlushCacheReq()
+        self.send_to_router.send_pyobj(flush_cache_req)
+
     async def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
python/sglang/srt/server.py

@@ -71,6 +71,15 @@ async def get_model_info():
     return result
 
 
+@app.get("/flush_cache")
+async def flush_cache():
+    await tokenizer_manager.flush_cache()
+    return Response(
+        content="Cache flushed.\nPlease check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)\n",
+        status_code=200,
+    )
+
+
 async def stream_generator(obj):
     async for out in tokenizer_manager.generate_request(obj):
         yield out
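With the pieces above in place, flushing the radix cache and KV pools is a single HTTP GET against the running server. A minimal usage sketch, assuming the server is reachable at http://localhost:30000 (the actual host and port depend on how the server was launched, so this URL is an assumption): note that the endpoint returns 200 regardless of whether the flush ran, so the router log is the only place that confirms the outcome when requests are still queued or running.

    import requests

    # Hypothetical base URL; substitute the host/port used to launch the server.
    resp = requests.get("http://localhost:30000/flush_cache")
    print(resp.status_code)  # 200
    print(resp.text)         # "Cache flushed. ..." -- check backend logs for the actual outcome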