Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b046cf79
Unverified
Commit
b046cf79
authored
May 23, 2025
by
Chauncey
Committed by
GitHub
May 23, 2025
Browse files
[Feature][V1]: suupports cached_tokens in response usage (#18149)
Co-authored-by:
simon-mo
<
xmo@berkeley.edu
>
parent
54af9159
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
5 deletions
+27
-5
tests/v1/core/test_scheduler_e2e.py
tests/v1/core/test_scheduler_e2e.py
+10
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-1
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+3
-0
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+6
-3
vllm/v1/request.py
vllm/v1/request.py
+4
-0
No files found.
tests/v1/core/test_scheduler_e2e.py
View file @
b046cf79
...
...
@@ -19,7 +19,8 @@ def model() -> LLM:
enable_prefix_caching
=
True
,
long_prefill_token_threshold
=
2
,
max_num_batched_tokens
=
6
,
max_num_seqs
=
3
)
max_num_seqs
=
3
,
block_size
=
16
)
def
test_concurrent_partial_prefill
(
model
):
...
...
@@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model):
assert
len
(
outputs
)
==
3
for
output
in
outputs
:
assert
len
(
output
.
outputs
)
==
1
def
test_prefix_cache_stats_is_recorded
(
model
):
# 17 tokens will make sure first 16 tokens are cached in a block
input_tokens
=
{
"prompt_token_ids"
:
[
101
]
*
17
}
_
=
model
.
generate
([
input_tokens
])
outputs
=
model
.
generate
([
input_tokens
])
assert
outputs
[
0
].
num_cached_tokens
==
16
vllm/v1/core/sched/scheduler.py
View file @
b046cf79
...
...
@@ -457,7 +457,9 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
# Count the number of prifix cached tokens.
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
# Encoder-related.
if
encoder_inputs_to_schedule
:
scheduled_encoder_inputs
[
request
.
request_id
]
=
(
...
...
@@ -798,6 +800,7 @@ class Scheduler(SchedulerInterface):
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
...
...
vllm/v1/engine/__init__.py
View file @
b046cf79
...
...
@@ -107,6 +107,9 @@ class EngineCoreOutput(
events
:
Optional
[
list
[
EngineCoreEvent
]]
=
None
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
# The number of tokens with prefix cache hits.
num_cached_tokens
:
int
=
0
@
property
def
finished
(
self
)
->
bool
:
return
self
.
finish_reason
is
not
None
...
...
vllm/v1/engine/output_processor.py
View file @
b046cf79
...
...
@@ -147,6 +147,7 @@ class RequestState:
finish_reason
:
Optional
[
FinishReason
],
stop_reason
:
Union
[
int
,
str
,
None
],
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
num_cached_tokens
:
int
=
0
,
)
->
Optional
[
RequestOutput
]:
finished
=
finish_reason
is
not
None
...
...
@@ -169,7 +170,7 @@ class RequestState:
return
None
return
self
.
_new_request_output
(
request_id
,
outputs
,
finished
,
kv_transfer_params
)
kv_transfer_params
,
num_cached_tokens
)
def
_new_request_output
(
self
,
...
...
@@ -177,6 +178,7 @@ class RequestState:
outputs
:
list
[
CompletionOutput
],
finished
:
bool
,
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
num_cached_tokens
:
int
=
0
,
)
->
RequestOutput
:
if
self
.
output_kind
==
RequestOutputKind
.
DELTA
:
...
...
@@ -193,6 +195,7 @@ class RequestState:
outputs
=
outputs
,
finished
=
finished
,
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
num_cached_tokens
,
)
def
_new_completion_output
(
...
...
@@ -340,7 +343,7 @@ class OutputProcessor:
finish_reason
=
engine_core_output
.
finish_reason
stop_reason
=
engine_core_output
.
stop_reason
kv_transfer_params
=
engine_core_output
.
kv_transfer_params
num_cached_tokens
=
engine_core_output
.
num_cached_tokens
req_state
.
is_prefilling
=
False
# 2) Detokenize the token ids into text and perform stop checks.
...
...
@@ -356,7 +359,7 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects.
if
request_output
:
=
req_state
.
make_request_output
(
new_token_ids
,
finish_reason
,
stop_reason
,
kv_transfer_params
):
kv_transfer_params
,
num_cached_tokens
):
if
req_state
.
queue
is
not
None
:
# AsyncLLM: put into queue for handling by generate().
req_state
.
queue
.
put
(
request_output
)
...
...
vllm/v1/request.py
View file @
b046cf79
...
...
@@ -77,6 +77,10 @@ class Request:
self
.
output_token_ids
=
ConstantList
(
self
.
_output_token_ids
)
self
.
all_token_ids
=
ConstantList
(
self
.
_all_token_ids
)
# State
# The number of tokens with prefix cache hits.
self
.
num_cached_tokens
=
-
1
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
if
request
.
mm_inputs
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment