Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7f280d69
Unverified
Commit
7f280d69
authored
Jul 01, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 01, 2025
Browse files
[Optimization] Cache sampled token ids in model runner (#20291)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
02cabff2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
91 additions
and
45 deletions
+91
-45
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+6
-6
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+2
-2
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+2
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+13
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+68
-32
No files found.
tests/v1/worker/test_gpu_model_runner.py
View file @
7f280d69
...
...
@@ -172,7 +172,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state
.
block_ids
[
0
]).
all
()
def
test_update_states_new_request
(
model_runner
):
def
test_update_states_new_request
(
model_runner
,
dist_init
):
req_id
=
"req_0"
# new req
...
...
@@ -186,7 +186,7 @@ def test_update_states_new_request(model_runner):
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_finished
(
model_runner
):
def
test_update_states_request_finished
(
model_runner
,
dist_init
):
req_id
=
"req_0"
# new req
...
...
@@ -218,7 +218,7 @@ def test_update_states_request_finished(model_runner):
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_request_resumed
(
model_runner
):
def
test_update_states_request_resumed
(
model_runner
,
dist_init
):
req_id
=
"req_0"
# new req
...
...
@@ -278,7 +278,7 @@ def test_update_states_request_resumed(model_runner):
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_get_nans_in_logits
(
model_runner
):
def
test_get_nans_in_logits
(
model_runner
,
dist_init
):
req_ids
=
(
"req_0"
,
"req_1"
)
scheduler_output
=
_schedule_new_request
(
*
req_ids
)
...
...
@@ -326,7 +326,7 @@ def test_get_nans_in_logits(model_runner):
assert
result
==
{
'req_0'
:
2
,
'req_1'
:
0
}
def
test_update_states_no_changes
(
model_runner
):
def
test_update_states_no_changes
(
model_runner
,
dist_init
):
req_id
=
"req_0"
# new req
...
...
@@ -359,7 +359,7 @@ def test_update_states_no_changes(model_runner):
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_unscheduled
(
model_runner
):
def
test_update_states_request_unscheduled
(
model_runner
,
dist_init
):
req_ids
=
(
"req_0"
,
"req_1"
)
# new reqs
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
7f280d69
...
...
@@ -307,7 +307,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
num_computed_tokens
=
cached_reqs
.
num_computed_tokens
[
i
]
new_token
_id
s
=
ca
ched
_reqs
.
new_token_ids
[
i
]
num_
new_tokens
=
s
ched
uler_output
.
num_scheduled_tokens
[
req_id
]
new_block_ids
=
cached_reqs
.
new_block_ids
[
i
]
resumed_from_preemption
=
cached_reqs
.
resumed_from_preemption
[
i
]
...
...
@@ -320,7 +320,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request
=
self
.
_requests_need_load
[
req_id
]
total_tokens
=
(
len
(
new
_token
_ids
)
+
num_
computed
_tokens
)
total_tokens
=
num_computed
_token
s
+
num_
new
_tokens
token_ids
=
request
.
all_token_ids
[:
total_tokens
]
# NOTE(rob): For resumed req, new_block_ids is all
...
...
vllm/v1/core/sched/output.py
View file @
7f280d69
...
...
@@ -88,6 +88,8 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
list
[
bool
]
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids
:
list
[
list
[
int
]]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]]
num_computed_tokens
:
list
[
int
]
...
...
vllm/v1/core/sched/scheduler.py
View file @
7f280d69
...
...
@@ -55,6 +55,7 @@ class Scheduler(SchedulerInterface):
self
.
lora_config
=
vllm_config
.
lora_config
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_events_config
=
vllm_config
.
kv_events_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
log_stats
=
log_stats
self
.
structured_output_manager
=
structured_output_manager
...
...
@@ -87,7 +88,7 @@ class Scheduler(SchedulerInterface):
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
vllm_config
.
parallel_config
.
data_parallel_rank
,
self
.
parallel_config
.
data_parallel_rank
,
)
num_gpu_blocks
=
self
.
cache_config
.
num_gpu_blocks
...
...
@@ -159,6 +160,7 @@ class Scheduler(SchedulerInterface):
log_stats
=
self
.
log_stats
,
enable_kv_cache_events
=
self
.
enable_kv_cache_events
,
)
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
...
...
@@ -214,7 +216,7 @@ class Scheduler(SchedulerInterface):
# This is necessary when using spec decoding.
num_new_tokens
=
min
(
num_new_tokens
,
self
.
max_model_len
-
request
.
num_computed_tokens
)
self
.
max_model_len
-
1
-
request
.
num_computed_tokens
)
# Schedule encoder inputs.
encoder_inputs_to_schedule
=
None
...
...
@@ -624,9 +626,15 @@ class Scheduler(SchedulerInterface):
req_ids
.
append
(
req_id
)
num_tokens
=
(
num_scheduled_tokens
[
req_id
]
-
len
(
spec_decode_tokens
.
get
(
req_id
,
())))
token_ids
=
req
.
all_token_ids
[
req
.
num_computed_tokens
:
req
.
num_computed_tokens
+
num_tokens
]
new_token_ids
.
append
(
token_ids
)
if
self
.
use_pp
:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids
=
req
.
all_token_ids
[
req
.
num_computed_tokens
:
req
.
num_computed_tokens
+
num_tokens
]
new_token_ids
.
append
(
token_ids
)
new_block_ids
.
append
(
req_to_new_block_ids
[
req_id
])
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
# Because resumed_reqs is usually empty, it is more efficient to do
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
7f280d69
...
...
@@ -470,26 +470,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add
.
append
(
req_id
)
# Update the states of the running/resumed requests.
is_last_rank
=
get_pp_group
().
is_last_rank
req_data
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
req_data
.
req_ids
):
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
new_token_ids
=
req_data
.
new_token_ids
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
resumed_from_preemption
=
req_data
.
resumed_from_preemption
[
i
]
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens
=
(
num_computed_tokens
+
len
(
new_token_ids
)
-
req_state
.
num_tokens
)
if
num_new_tokens
==
1
:
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
new_token_ids
[
-
num_new_tokens
:])
if
not
is_last_rank
:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids
=
req_data
.
new_token_ids
[
i
]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens
=
(
num_computed_tokens
+
len
(
new_token_ids
)
-
req_state
.
num_tokens
)
if
num_new_tokens
==
1
:
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
new_token_ids
[
-
num_new_tokens
:])
# Update the block IDs.
if
not
resumed_from_preemption
:
# Append the new blocks to the existing block IDs.
...
...
@@ -513,22 +520,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
num_computed_tokens
)
self
.
input_batch
.
block_table
.
append_row
(
new_block_ids
,
req_index
)
# Add new_token_ids to token_ids_cpu.
start_token_index
=
num_computed_tokens
end_token_index
=
num_computed_tokens
+
len
(
new_token_ids
)
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_token_index
:
end_token_index
]
=
new_token_ids
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
())
if
spec_token_ids
:
start_index
=
end_token_index
end_token_index
+=
len
(
spec_token_ids
)
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if
not
is_last_rank
:
# Add new_token_ids to token_ids_cpu.
start_token_index
=
num_computed_tokens
end_token_index
=
num_computed_tokens
+
len
(
new_token_ids
)
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_index
:
end_token_index
]
=
spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self
.
input_batch
.
num_tokens
[
req_index
]
=
end_token_index
req_index
,
start_token_index
:
end_token_index
]
=
new_token_ids
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
if
spec_token_ids
:
start_index
=
end_token_index
end_token_index
+=
len
(
spec_token_ids
)
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_index
:
end_token_index
]
=
spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self
.
input_batch
.
num_tokens
[
req_index
]
=
end_token_index
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
...
...
@@ -1509,6 +1524,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
if
not
sampled_ids
:
continue
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
self
.
max_model_len
,
(
"Sampled token IDs exceed the max model length. "
f
"Total number of tokens:
{
end_idx
}
> max_model_len: "
f
"
{
self
.
max_model_len
}
"
)
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
req_state
=
self
.
requests
[
req_id
]
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
...
...
@@ -1730,17 +1769,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
.
append
([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
end_idx
=
start_idx
+
num_sampled_ids
if
end_idx
>=
self
.
max_model_len
:
num_tokens
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
if
num_tokens
>=
self
.
max_model_len
:
# Skip requests that have already reached the max model length.
draft_token_ids
.
append
([])
continue
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
])
self
.
input_batch
.
token_ids_cpu
[
i
,
:
num_tokens
])
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
draft_token_ids
.
append
([])
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment