sglang · Commit ecb8bad2 (Unverified)
Authored Oct 16, 2024 by havetc, committed by GitHub on Oct 16, 2024

Returning a per request metric for number of cached_tokens read (#1599)
Parent: dbec2f18

Showing 7 changed files with 245 additions and 3 deletions (+245 -3):

  python/sglang/srt/managers/schedule_batch.py    +10   -0
  python/sglang/srt/managers/schedule_policy.py    +1   -0
  python/sglang/srt/managers/scheduler.py          +1   -0
  python/sglang/srt/openai_api/adapter.py         +14   -3
  python/sglang/srt/openai_api/protocol.py         +2   -0
  python/sglang/srt/server_args.py                 +6   -0
  test/srt/test_cache_report.py                  +211   -0
python/sglang/srt/managers/schedule_batch.py

@@ -196,6 +196,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of cached tokens, that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None

@@ -499,6 +502,13 @@ class ScheduleBatch:
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
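The already_computed correction above matters under chunked prefill: when a long prompt is split across several extend passes, req.prefix_indices grows to include the chunks this same request wrote in earlier passes, and those tokens must not be counted again as cache hits. The following is a minimal illustrative sketch of just that bookkeeping, not part of the diff; the chunk sizes and the assumption that extend_logprob_start_len equals the previously prefilled length minus one are made up for the example.

# Hedged sketch: re-implements only the cached_tokens accounting shown above,
# with made-up numbers for a two-chunk prefill.

def update_cached_tokens(cached_tokens, extend_logprob_start_len, prefix_len):
    # Mirrors the hunk: already_computed compensates for tokens this request
    # already wrote into the cache (plus what was already counted).
    already_computed = (
        extend_logprob_start_len + 1 + cached_tokens
        if extend_logprob_start_len > 0
        else 0
    )
    return cached_tokens + prefix_len - already_computed


# 100-token prompt, 20 tokens reused from another request, chunked-prefill-size 40.
cached = 0

# Pass 1: only the 20 genuinely reused tokens are in prefix_indices.
cached = update_cached_tokens(cached, extend_logprob_start_len=0, prefix_len=20)
assert cached == 20

# Pass 2: prefix_indices now also holds the 40 tokens prefilled in pass 1
# (20 + 40 = 60), but they are subtracted out and not counted as hits.
cached = update_cached_tokens(cached, extend_logprob_start_len=39, prefix_len=60)
assert cached == 20  # still only the tokens that were truly read from the cache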
python/sglang/srt/managers/schedule_policy.py

@@ -51,6 +51,7 @@ class SchedulePolicy:
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=r.adjust_max_prefix_ids()
                )

            prefix_computed = True

        if self.policy == "lpm":
python/sglang/srt/managers/scheduler.py

@@ -978,6 +978,7 @@ class Scheduler:
                 "prompt_tokens": len(req.origin_input_ids),
                 "completion_tokens": len(req.output_ids),
                 "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                "cached_tokens": req.cached_tokens,
                 "finish_reason": (
                     req.finished_reason.to_json()
                     if req.finished_reason is not None
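With this key in place, the native /generate endpoint exposes the counter directly in meta_info. A hedged usage sketch (the URL, port, prompt, and the concrete numbers are placeholders; the field names come from the hunk above and from the new test file):

import requests

# Assumes an SGLang server is already running locally on port 30000.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},
    },
)
meta = resp.json()["meta_info"]

# e.g. {"prompt_tokens": 7, "completion_tokens": 16,
#       "completion_tokens_wo_jump_forward": 16, "cached_tokens": 6, ...}
print(meta["cached_tokens"], "of", meta["prompt_tokens"], "prompt tokens were reused")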
python/sglang/srt/openai_api/adapter.py

@@ -302,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             if not isinstance(ret, list):
                 ret = [ret]
             if end_point == "/v1/chat/completions":
-                responses = v1_chat_generate_response(request, ret, to_file=True)
+                responses = v1_chat_generate_response(
+                    request,
+                    ret,
+                    to_file=True,
+                    cache_report=tokenizer_manager.server_args.enable_cache_report,
+                )
             else:
                 responses = v1_generate_response(
                     request, ret, tokenizer_manager, to_file=True

@@ -970,7 +975,7 @@ def v1_chat_generate_request(
     return adapted_request, all_requests
 
 
-def v1_chat_generate_response(request, ret, to_file=False):
+def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
     choices = []
 
     for idx, ret_item in enumerate(ret):

@@ -1067,6 +1072,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
         )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,

@@ -1075,6 +1081,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=prompt_tokens + completion_tokens,
+                prompt_tokens_details=(
+                    {"cached_tokens": cached_tokens} if cache_report else None
+                ),
             ),
         )
         return response

@@ -1240,7 +1249,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         if not isinstance(ret, list):
             ret = [ret]
 
-        response = v1_chat_generate_response(request, ret)
+        response = v1_chat_generate_response(
+            request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+        )
 
         return response
python/sglang/srt/openai_api/protocol.py

@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    # only used to return cached tokens when --enable-cache-report is set
+    prompt_tokens_details: Optional[Dict[str, int]] = None
 
 
 class StreamOptions(BaseModel):
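For reference, this is the shape the new field takes in the OpenAI-compatible usage block; a small illustrative sketch using the model above (not part of the diff, and the numbers are made up):

from sglang.srt.openai_api.protocol import UsageInfo

# With --enable-cache-report the adapter fills prompt_tokens_details;
# without the flag it is left as None.
usage = UsageInfo(
    prompt_tokens=42,
    total_tokens=142,
    completion_tokens=100,
    prompt_tokens_details={"cached_tokens": 35},
)
print(usage.prompt_tokens_details["cached_tokens"])  # 35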
python/sglang/srt/server_args.py
View file @
ecb8bad2
...
...
@@ -73,6 +73,7 @@ class ServerArgs:
# Other
api_key
:
Optional
[
str
]
=
None
file_storage_pth
:
str
=
"SGLang_storage"
enable_cache_report
:
bool
=
False
# Data parallelism
dp_size
:
int
=
1
...
...
@@ -410,6 +411,11 @@ class ServerArgs:
default
=
ServerArgs
.
file_storage_pth
,
help
=
"The path of the file storage in backend."
,
)
parser
.
add_argument
(
"--enable-cache-report"
,
action
=
"store_true"
,
help
=
"Return number of cached tokens in usage.prompt_tokens_details for each openai request."
,
)
# Data parallelism
parser
.
add_argument
(
...
...
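End to end, the flag is passed at launch time and the counter then appears in every chat completion's usage block. A hedged sketch, assuming the usual sglang.launch_server entry point; the model path, port, model name, and prompt are placeholders, while the flag name and the usage.prompt_tokens_details.cached_tokens accessor come from this diff and the test below:

# Launch the server with the new flag, e.g.:
#   python -m sglang.launch_server --model-path <your-model> --port 30000 \
#       --enable-cache-report
import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

for attempt in range(2):
    resp = client.chat.completions.create(
        model="default",  # placeholder model name
        messages=[{"role": "user", "content": "Introduce the capital of France."}],
        temperature=0,
        max_tokens=32,
    )
    usage = resp.usage
    # The second identical request should report most of the prompt as cached.
    print(attempt, usage.prompt_tokens, usage.prompt_tokens_details.cached_tokens)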
test/srt/test_cache_report.py  (new file, 0 → 100644)

import asyncio
import json
import unittest

import openai
import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestCacheReport(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.min_cached = 5
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--chunked-prefill-size=40",
                "--enable-cache-report",
            ],
        )
        cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")

        usage = cls.run_openai(cls, "1").usage
        # we can assume that our request is of size 1, plus the total template size
        # ideally we would like to know the begin size / end size of the template to be more precise
        total_template_size = usage.prompt_tokens - 1
        print(f"template size: {total_template_size}")
        usage2 = cls.run_openai(cls, "2").usage
        assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
        cls.min_cached = max(
            usage2.prompt_tokens_details.cached_tokens,
            total_template_size - usage2.prompt_tokens_details.cached_tokens,
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        response = requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        return response

    def run_openai(self, message):
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    async def run_openai_async(self, message):
        response = await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    def cache_report_openai(self, message):
        response = self.run_openai(message)
        print(
            f"openai first request cached_tokens: "
            f"{int(response.usage.prompt_tokens_details.cached_tokens)}"
        )
        first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        # assert int(response.usage.cached_tokens) == 0
        assert first_cached_tokens < self.min_cached
        response = self.run_openai(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        print(f"openai second request cached_tokens: {cached_tokens}")
        assert cached_tokens > 0
        assert cached_tokens == int(response.usage.prompt_tokens) - 1
        return first_cached_tokens

    async def cache_report_openai_async(self, message):
        response = await self.run_openai_async(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        prompt_tokens = int(response.usage.prompt_tokens)
        return cached_tokens, prompt_tokens

    def test_generate(self):
        print("=" * 100)
        response = self.run_decode()
        # print(response.json())
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang first request cached_tokens: {cached_tokens}")
        print(
            f"sglang first request prompt_tokens: "
            f"{int(response.json()['meta_info']['prompt_tokens'])}"
        )
        # can't guarantee this is 0: it depends on the initialisation request / whether a template is used with the model
        assert cached_tokens < self.min_cached
        response = self.run_decode()
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang second request cached_tokens: {cached_tokens}")
        print(
            f"sglang second request prompt_tokens: "
            f"{int(response.json()['meta_info']['prompt_tokens'])}"
        )
        assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1

    def test_cache_split_prefill_openai(self):
        print("=" * 100)
        self.cache_report_openai(
            "€ This is a very long and unique text that should not be already cached, the twist is"
            " that it should be longer than the chunked-prefill-size, so it should be split among"
            " several prefill requests. Still, it shouldn't be cached"
        )

    def test_cache_report_openai(self):
        print("=" * 100)
        # warm up the cache, for the template
        self.run_openai("Introduce the capital of France.")
        first_cached_tokens_1 = self.run_openai(
            "How many sparrow do you need to lift a coconut?"
        ).usage.prompt_tokens_details.cached_tokens
        usage_2 = self.run_openai("* sing something about cats").usage
        first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
        # first request may not have 0 cached tokens, but if they only have the template in common they
        # should be the same once the cache is warmed up
        assert first_cached_tokens_1 == first_cached_tokens_2
        resp = self.run_openai("* sing something about cats and dogs")
        print(resp.usage)
        resp = self.run_openai("* sing something about cats, please")
        print(resp.usage)
        assert (
            resp.usage.prompt_tokens_details.cached_tokens
            >= usage_2.prompt_tokens - self.min_cached
        )

    def test_cache_report_openai_async(self):
        print("=" * 100)

        async def run_test():
            task0 = asyncio.create_task(
                self.cache_report_openai_async(
                    "first request, to start the inference and let the next two request be started in the same batch"
                )
            )
            await asyncio.sleep(0.05)  # to force the first request to be started first
            task1 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            task2 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )

            result0, result1, result2 = await asyncio.gather(task0, task1, task2)

            cached_tokens0, prompt_tokens0 = result0
            cached_tokens1, prompt_tokens1 = result1
            cached_tokens2, prompt_tokens2 = result2
            print(
                f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
            )
            print(
                f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
            )
            print(
                f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
            )

            # Assert that no requests used the cache (because the first is alone, and the next two are in the same batch)
            # If a new optimisation limiting starting requests with the same prefix at the same time were added
            # to maximise the cache hit rate, this would not be true
            assert cached_tokens1 == cached_tokens2 == cached_tokens0

        asyncio.run(run_test())


if __name__ == "__main__":
    unittest.main()