Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
72f4d162
Unverified
Commit
72f4d162
authored
Mar 01, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 01, 2026
Browse files
[Model Runner V2] Use block table apis for capture inputs (#35671)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
5a435507
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
2 deletions
+13
-2
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+11
-0
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+2
-2
No files found.
vllm/v1/worker/gpu/block_table.py
View file @
72f4d162
...
...
@@ -119,6 +119,10 @@ class BlockTables:
return
tuple
(
block_table
[:
num_reqs
]
for
block_table
in
self
.
input_block_tables
)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
    """Return the leading `num_reqs` entries of every input block table.

    Args:
        num_reqs: Number of leading entries to keep from each tensor in
            `self.input_block_tables`.

    Returns:
        A tuple with one `[:num_reqs]` slice per input block table, in the
        same order as `self.input_block_tables`.
    """
    # NOTE(woosuk): The output may be used for CUDA graph capture.
    # Therefore, this method must hand back slices of the persistent tensors
    # (same memory address as used during the model's forward pass) and must
    # not allocate new tensors.
    dummy_tables = []
    for table in self.input_block_tables:
        dummy_tables.append(table[:num_reqs])
    return tuple(dummy_tables)
def
compute_slot_mappings
(
...
...
@@ -150,7 +154,14 @@ class BlockTables:
return
self
.
slot_mappings
[:,
:
num_tokens
]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
    """Reset the slot mappings to padding and return the first `num_tokens` columns.

    Args:
        num_tokens: Number of leading columns of `self.slot_mappings` to
            return.

    Returns:
        The slice `self.slot_mappings[:, :num_tokens]`, taken after the
        whole tensor has been filled in place with `PAD_SLOT_ID`.
    """
    slot_mappings = self.slot_mappings
    # Pad the entire tensor, not just the first `num_tokens` entries: the
    # padding logic is complex and kernels may access beyond the requested
    # range.
    slot_mappings.fill_(PAD_SLOT_ID)
    # NOTE(woosuk): The output may be used for CUDA graph capture.
    # Therefore, this method must hand back a slice of the persistent tensor
    # (same memory address as used during the model's forward pass) and must
    # not allocate a new tensor.
    return slot_mappings[:, :num_tokens]
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
72f4d162
...
...
@@ -420,8 +420,8 @@ def prepare_inputs_to_capture(
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
=
num_tokens
input_buffers
.
dcp_local_seq_lens
[
num_reqs
:]
=
0
input_block_tables
=
[
x
[:
num_reqs
]
for
x
in
block_tables
.
input
_block_tables
]
slot_mappings
=
block_tables
.
slot_mappings
[:,
:
num_tokens
]
input_block_tables
=
block_tables
.
get_dummy
_block_tables
(
num_reqs
)
slot_mappings
=
block_tables
.
get_dummy_
slot_mappings
(
num_tokens
)
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
kv_cache_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment