Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f088a831
Unverified
Commit
f088a831
authored
Mar 10, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 10, 2026
Browse files
[Model Runner V2] Use unpadded num_tokens for PW CUDA graph attn metadata (#36626)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
f83b933b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
3 deletions
+14
-3
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+1
-0
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+10
-3
vllm/v1/worker/gpu/model_states/interface.py
vllm/v1/worker/gpu/model_states/interface.py
+2
-0
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
f088a831
...
@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
...
@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
attn_metadata
=
model_state
.
prepare_attn
(
attn_metadata
=
model_state
.
prepare_attn
(
input_batch
,
input_batch
,
CUDAGraphMode
.
NONE
,
input_block_tables
,
input_block_tables
,
slot_mappings
,
slot_mappings
,
attn_groups
,
attn_groups
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
f088a831
...
@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
block_tables
is
not
None
assert
block_tables
is
not
None
attn_metadata
=
self
.
model_state
.
prepare_attn
(
attn_metadata
=
self
.
model_state
.
prepare_attn
(
input_batch
,
input_batch
,
batch_desc
.
cg_mode
,
block_tables
,
block_tables
,
slot_mappings
,
slot_mappings
,
self
.
attn_groups
,
self
.
attn_groups
,
...
...
vllm/v1/worker/gpu/model_states/default.py
View file @
f088a831
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
...
@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
...
@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
def
prepare_attn
(
def
prepare_attn
(
self
,
self
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
cudagraph_mode
:
CUDAGraphMode
,
block_tables
:
tuple
[
torch
.
Tensor
,
...],
block_tables
:
tuple
[
torch
.
Tensor
,
...],
slot_mappings
:
torch
.
Tensor
,
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs
=
input_batch
.
num_reqs_after_padding
num_reqs
=
input_batch
.
num_reqs_after_padding
num_tokens
=
input_batch
.
num_tokens_after_padding
num_tokens
=
input_batch
.
num_tokens_after_padding
else
:
# For piecewise cudagraphs and eager, use unpadded sizes.
num_reqs
=
input_batch
.
num_reqs
num_tokens
=
input_batch
.
num_tokens
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
...
...
vllm/v1/worker/gpu/model_states/interface.py
View file @
f088a831
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
...
@@ -59,6 +60,7 @@ class ModelState(ABC):
...
@@ -59,6 +60,7 @@ class ModelState(ABC):
def
prepare_attn
(
def
prepare_attn
(
self
,
self
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
cudagraph_mode
:
CUDAGraphMode
,
block_tables
:
tuple
[
torch
.
Tensor
,
...],
block_tables
:
tuple
[
torch
.
Tensor
,
...],
slot_mappings
:
torch
.
Tensor
,
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
attn_groups
:
list
[
list
[
AttentionGroup
]],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment