change / sglang · Commits
"app/vscode:/vscode.git/clone" did not exist on "52ea4d4bb2812f7ed7611e868fa429bf29970a1c"
Unverified commit 81561f8e, authored Jan 26, 2024 by Liangsheng Yin; committed by GitHub Jan 25, 2024.

Flush Cache API (#103)

Parent: 3a581e99
Showing 6 changed files with 48 additions and 64 deletions.
python/sglang/flush_cache.py                        +0  -60
python/sglang/srt/managers/io_struct.py             +5  -0
python/sglang/srt/managers/router/model_rpc.py      +23 -1
python/sglang/srt/managers/router/radix_cache.py    +6  -3
python/sglang/srt/managers/tokenizer_manager.py     +5  -0
python/sglang/srt/server.py                         +9  -0
python/sglang/flush_cache.py (deleted, 100644 → 0)
"""Flush cache in the backend by sending random requests."""
import
argparse
import
random
import
string
import
time
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
select_sglang_backend
,
)
import
sglang
as
sgl
@
sgl
.
function
def
flush_radix_cache
(
s
,
prompt
):
s
+=
prompt
+
sgl
.
gen
(
"flush"
,
max_tokens
=
1
,
stop
=
"END"
)
def
main
(
args
,
max_total_tokens
,
context_length
,
print_flag
):
backend
=
select_sglang_backend
(
args
)
flush_length
=
int
(
context_length
*
0.8
)
batch_size
=
int
(
max_total_tokens
/
flush_length
)
prompt_length
=
flush_length
*
2
prompts
=
[
" "
.
join
(
random
.
choices
(
string
.
ascii_letters
,
k
=
int
(
prompt_length
)))
for
_
in
range
(
batch_size
)
]
arguments
=
[{
"prompt"
:
prompts
[
i
]}
for
i
in
range
(
batch_size
)]
start_time
=
time
.
time
()
flush_radix_cache
.
run_batch
(
arguments
,
temperature
=
0
,
backend
=
backend
,
num_threads
=
1
)
end_time
=
time
.
time
()
if
print_flag
:
print
(
f
"Flush length:
{
flush_length
}
\n
"
,
f
"Prompt length:
{
prompt_length
}
\n
"
,
f
"Total Prompt letters:
{
batch_size
*
prompt_length
}
\n
"
,
f
"Flush radix cache latency:
{
end_time
-
start_time
:.
3
f
}
"
,
sep
=
""
,
)
# to prevent the backend still running
time
.
sleep
(
1
)
def
run_flush
(
args
,
max_total_tokens
=
20000
,
context_length
=
1024
,
print_flag
=
False
):
main
(
args
,
max_total_tokens
,
context_length
,
print_flag
=
print_flag
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--max-total-tokens"
,
type
=
int
,
default
=
20000
)
parser
.
add_argument
(
"--context-length"
,
type
=
int
,
default
=
1024
)
args
=
add_common_sglang_args_and_parse
(
parser
)
random
.
seed
(
0
)
main
(
args
,
args
.
max_total_tokens
,
args
.
context_length
,
print_flag
=
True
)
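For reference, here is the sizing arithmetic from main() evaluated at the script's defaults (max_total_tokens=20000, context_length=1024); this is just a worked restatement of the formulas above, not part of the commit:

# Worked example of the sizing math in main() at the default arguments.
context_length = 1024
max_total_tokens = 20000

flush_length = int(context_length * 0.8)           # 819 tokens to flush per request
batch_size = int(max_total_tokens / flush_length)  # 24 requests fill the token pool
prompt_length = flush_length * 2                   # 1638 random letters per prompt

print(flush_length, batch_size, prompt_length)     # 819 24 1638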
python/sglang/srt/managers/io_struct.py
@@ -87,3 +87,8 @@ class BatchStrOut:
     output_str: List[str]
     meta_info: List[Dict]
     finished: List[bool]
+
+
+@dataclass
+class FlushCacheReq:
+    pass
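A note on the pattern: FlushCacheReq deliberately has no fields, so the pickled object's type is the whole message; the router dispatches on it with isinstance (see model_rpc.py below). A minimal sketch, assuming sglang is importable:

from sglang.srt.managers.io_struct import FlushCacheReq

req = FlushCacheReq()  # empty marker object; the type alone is the signal
assert isinstance(req, FlushCacheReq)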
python/sglang/srt/managers/router/model_rpc.py
@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
 from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
+from sglang.srt.managers.io_struct import (
+    BatchTokenIDOut,
+    TokenizedGenerateReqInput,
+    FlushCacheReq,
+)
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                "Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
+            elif isinstance(recv_req, FlushCacheReq):
+                self.flush_cache()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
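The guard at the top of flush_cache() is the key invariant: the radix tree and KV pools may only be cleared while nothing references them. A standalone restatement of that precondition (can_flush is a hypothetical helper, not part of the commit):

# Hypothetical helper restating the flush precondition: clearing the pools
# is only safe when no request is queued and no batch is running.
def can_flush(forward_queue, running_batch):
    return len(forward_queue) == 0 and (
        running_batch is None or len(running_batch.reqs) == 0
    )

assert can_flush([], None)           # idle server: flush proceeds
assert not can_flush(["req"], None)  # queued work: flush is refused with a warning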
python/sglang/srt/managers/router/radix_cache.py
@@ -30,14 +30,17 @@ def match(key, seq):
 class RadixCache:
     def __init__(self, disable=False):
+        self.reset()
+        self.disable = disable
+
+    ##### Public API #####
+
+    def reset(self):
         self.root_node = TreeNode()
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
-        self.disable = disable
-
-    ##### Public API #####
 
     def match_prefix(self, key):
         if self.disable:
             return [], self.root_node
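Factoring the field initialization out of __init__ into reset() is what lets ModelRpcServer.flush_cache() call tree_cache.reset() to drop every cached prefix in place; note that reset() leaves the disable flag untouched, since only __init__ sets it now. A minimal usage sketch, assuming the sglang source tree is on the path:

from sglang.srt.managers.router.radix_cache import RadixCache

cache = RadixCache(disable=False)
# ... serve requests, populating the tree ...
cache.reset()                      # fresh root node, nothing left to evict
assert cache.evictable_size_ == 0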
python/sglang/srt/managers/tokenizer_manager.py
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     GenerateReqInput,
     TokenizedGenerateReqInput,
+    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -228,6 +229,10 @@ class TokenizerManager:
             yield output_list
 
+    async def flush_cache(self):
+        flush_cache_req = FlushCacheReq()
+        self.send_to_router.send_pyobj(flush_cache_req)
+
     async def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
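Note that TokenizerManager.flush_cache() is fire-and-forget: it pushes a FlushCacheReq onto the ZMQ socket and returns without waiting for the router's outcome, which is why the HTTP handler below points clients at the backend logs. A hypothetical call site (tm stands in for a fully constructed TokenizerManager):

import asyncio

async def flush(tm):
    await tm.flush_cache()  # enqueues the request; returns immediately

# asyncio.run(flush(tm))    # run against a real TokenizerManager instance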
python/sglang/srt/server.py
@@ -71,6 +71,15 @@ async def get_model_info():
     return result
 
 
+@app.get("/flush_cache")
+async def flush_cache():
+    await tokenizer_manager.flush_cache()
+    return Response(
+        content="Cache flushed.\nPlease check backend logs for more details. "
+        "(When there are running or waiting requests, the operation will not be performed.)\n",
+        status_code=200,
+    )
+
+
 async def stream_generator(obj):
     async for out in tokenizer_manager.generate_request(obj):
         yield out
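Once the server is running, the new endpoint can be exercised with a plain GET. A usage sketch; the host and port are assumptions (sglang's default port is 30000), so adjust them to your deployment:

import requests

resp = requests.get("http://localhost:30000/flush_cache")
print(resp.status_code)  # 200 whether or not the flush ran; the real outcome
print(resp.text)         # is in the backend logs, per the response text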